Merge branch 'Comfy-Org:master' into qwen-image-vae

This commit is contained in:
brucew4yn3rp 2026-06-24 09:36:00 -04:00 committed by GitHub
commit ca686606ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
191 changed files with 34647 additions and 1171 deletions

View File

@ -1,5 +1,4 @@
As of the time of writing this you need this driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
HOW TO RUN:
@ -7,9 +6,9 @@ If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
run_amd_gpu_enable_dynamic_vram.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

View File

@ -17,7 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
# Flag to track if CRLF is found
CRLF_FOUND=false

View File

@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
- Serves as the foundation for the desktop release
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
2. **[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)**
- Builds a new release using the latest stable core version
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
@ -309,7 +309,7 @@ After this you should have everything installed and can proceed to running Comfy
#### Apple Mac silicon
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
You can install ComfyUI in Apple Mac silicon (M1, M2, M3 or M4) with any recent macOS version.
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
| Flag | Description |
|------|-------------|
| `--enable-manager` | Enable ComfyUI-Manager |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
@ -382,11 +382,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
# Notes
@ -462,16 +458,6 @@ To use the most up-to-date frontend version:
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
# QA
### Which GPU should I buy for this?

View File

@ -0,0 +1,39 @@
"""
Drop the vestigial tags.tag_type column.
tag_type was always "user" in practice — no code path ever set it to anything
else (no system/seeded classification was ever wired up) and nothing queried it.
The column, its index (ix_tags_tag_type), and the corresponding API field were
dead weight, so they are removed.
Revision ID: 0004_drop_tag_type
Revises: 0003_add_metadata_job_id
Create Date: 2026-06-03
"""
from alembic import op
import sqlalchemy as sa
revision = "0004_drop_tag_type"
down_revision = "0003_add_metadata_job_id"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Remove the ``ix_tags_tag_type`` index and the ``tag_type`` column from ``tags``.

    Both operations run inside one Alembic batch context; the index is dropped
    before the column it covers.
    """
    with op.batch_alter_table("tags") as tags_table:
        tags_table.drop_index("ix_tags_tag_type")
        tags_table.drop_column("tag_type")
def downgrade() -> None:
    """Re-create ``tags.tag_type`` and its ``ix_tags_tag_type`` index.

    The column is restored as a non-nullable ``String(32)`` with a server
    default of ``"user"``.
    """
    tag_type_column = sa.Column(
        "tag_type",
        sa.String(length=32),
        nullable=False,
        server_default="user",
    )
    with op.batch_alter_table("tags") as tags_table:
        tags_table.add_column(tag_type_column)
        tags_table.create_index("ix_tags_tag_type", ["tag_type"])

View File

@ -39,6 +39,7 @@ from app.assets.services import (
update_asset_metadata,
upload_from_temp_path,
)
from app.assets.services.cursor import InvalidCursorError
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef()
@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
job_id=result.ref.job_id,
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response:
order_candidate = (q.order or "desc").lower()
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
)
try:
result = list_assets_page(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
offset=q.offset,
sort=sort,
order=order,
after=q.after,
)
except InvalidCursorError as e:
return _build_error_response(400, "INVALID_CURSOR", str(e))
summaries = [_build_asset_response(item) for item in result.items]
# has_more semantics differ by mode:
# - cursor mode: a non-empty next_cursor means there are more results.
# - offset mode: derived from total - (offset + page size).
if q.after is not None:
has_more = result.next_cursor is not None
else:
has_more = (q.offset + len(summaries)) < result.total
payload = schemas_out.AssetsList(
assets=summaries,
total=result.total,
has_more=(q.offset + len(summaries)) < result.total,
has_more=has_more,
next_cursor=result.next_cursor,
)
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ -519,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response:
@_require_assets_feature_enabled
async def delete_asset_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
delete_content_param = request.query.get("delete_content")
delete_content = (
False
if delete_content_param is None
else delete_content_param.lower() not in {"0", "false", "no"}
)
try:
# Deleting an asset is a soft delete of the reference; the underlying
# content is preserved (it may be shared with other references).
deleted = delete_asset_reference(
reference_id=reference_id,
owner_id=USER_MANAGER.get_request_user_id(request),
delete_content_if_orphan=delete_content,
delete_content_if_orphan=False,
)
except Exception:
logging.exception(
@ -575,8 +585,8 @@ async def get_tags(request: web.Request) -> web.Response:
)
tags = [
schemas_out.TagUsage(name=name, count=count, type=tag_type)
for (name, tag_type, count) in rows
schemas_out.TagUsage(name=name, count=count)
for (name, count) in rows
]
payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total

View File

@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
# Supplying `after` together with `sort=last_access_time` returns
# 400 INVALID_CURSOR; that sort only supports offset/limit.
after: str | None = None
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"

View File

@ -41,12 +41,13 @@ class AssetsList(BaseModel):
assets: list[Asset]
total: int
has_more: bool
# Opaque cursor for the next page. Omitted when there are no more results.
next_cursor: str | None = None
class TagUsage(BaseModel):
name: str
count: int
type: str
class TagsList(BaseModel):

View File

@ -227,7 +227,6 @@ class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="tag",
@ -240,7 +239,5 @@ class Tag(Base):
overlaps="asset_reference_links,tag_links,tags,asset_reference",
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -266,9 +266,18 @@ def list_references_page(
metadata_filter: dict | None = None,
sort: str | None = None,
order: str | None = None,
after_cursor_value: object | None = None,
after_cursor_id: str | None = None,
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
"""List references with pagination, filtering, and sorting.
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
strictly after the given ``(sort_col, id)`` position in the active sort
direction. The cursor value must already be typed for the column
(datetime for time sorts, int for size, str for name); the caller decodes
the opaque cursor string and resolves to the typed value.
Returns (references, tag_map, total_count).
"""
base = (
@ -297,9 +306,31 @@ def list_references_page(
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetReference.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
descending = order == "desc"
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
if after_cursor_value is not None and after_cursor_id is not None:
if descending:
keyset = sa.or_(
sort_col < after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
)
else:
keyset = sa.or_(
sort_col > after_cursor_value,
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
)
base = base.where(keyset)
# Secondary ORDER BY id (matching the primary direction) gives the keyset
# comparison a deterministic tiebreaker on duplicate sort_col values.
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
sort_exp = sort_col.desc() if descending else sort_col.asc()
base = base.order_by(sort_exp, id_exp).limit(limit)
if after_cursor_id is None:
base = base.offset(offset)
count_stmt = (
select(sa.func.count())

View File

@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None:
raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
def ensure_tags_exist(session: Session, names: Iterable[str]) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
rows = [{"name": n} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
@ -97,7 +95,7 @@ def set_reference_tags(
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
ensure_tags_exist(session, to_add)
session.add_all(
[
AssetReferenceTag(
@ -142,7 +140,7 @@ def add_tags_to_reference(
return AddTagsResult(added=[], already_present=[], total_tags=total)
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
ensure_tags_exist(session, norm)
current = set(get_reference_tags(session, reference_id))
@ -289,7 +287,6 @@ def list_tags_with_usage(
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
@ -331,7 +328,7 @@ def list_tags_with_usage(
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
rows_norm = [(name, int(count or 0)) for (name, count) in rows]
return rows_norm, int(total or 0)

View File

@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -354,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
ensure_tags_exist(sess, tag_pool)
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_refs
@ -506,6 +507,10 @@ def enrich_asset(
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
if mime_type and mime_type.startswith("image/"):
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if dims:
system_metadata.update(dims)
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:

View File

@ -1,8 +1,19 @@
import contextlib
import mimetypes
import os
from datetime import timezone
from typing import Sequence
from app.assets.services.cursor import (
CursorPayload,
InvalidCursorError,
decode_cursor,
decode_cursor_int,
decode_cursor_time,
encode_cursor,
encode_cursor_from_time,
)
from app.assets.database.models import Asset
from app.assets.database.queries import (
@ -149,6 +160,16 @@ def delete_asset_reference(
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
"""Delete an asset reference.
With ``delete_content_if_orphan=False`` (a soft delete), the reference is
hidden and the underlying content is preserved. With ``True``, the content
is also removed once it becomes orphaned.
Note: the public DELETE /api/assets/{id} endpoint always soft-deletes
(passes ``False``); the orphan-reclamation path is intentionally
internal-only, retained for a future GC/admin caller.
"""
with create_session() as session:
if not delete_content_if_orphan:
# Soft delete: mark the reference as deleted but keep everything
@ -242,6 +263,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
return extract_asset_data(asset)
# Sort fields that support cursor pagination. `last_access_time` is not
# in this list — it falls back to offset/limit.
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
@ -252,7 +278,39 @@ def list_assets_page(
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
after: str | None = None,
) -> ListAssetsResult:
"""List assets with optional cursor pagination.
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
must match ``sort`` and be in the cursor-supported allowlist; mismatches
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
"""
cursor_value: object | None = None
cursor_id: str | None = None
# Mint next_cursor on every page where the sort is cursor-supported, not
# only when the request itself arrived with a cursor. Otherwise a first
# request (no `after`) returns next_cursor=None and the client can never
# enter cursor mode.
mint_cursor = sort in _CURSOR_SORT_FIELDS
if after is not None:
if sort not in _CURSOR_SORT_FIELDS:
raise InvalidCursorError(
f"cursor pagination is not supported for sort={sort!r}"
)
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
if payload.sort_field != sort:
raise InvalidCursorError(
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
)
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
# remaining" from "more rows past this page" without a second query. Drop
# the sentinel before returning.
fetch_limit = limit + 1 if mint_cursor else limit
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
@ -261,12 +319,22 @@ def list_assets_page(
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
limit=fetch_limit,
offset=offset,
sort=sort,
order=order,
after_cursor_value=cursor_value,
after_cursor_id=cursor_id,
)
next_cursor: str | None = None
if mint_cursor and len(refs) > limit:
# There's at least one more row past this page — mint a cursor from
# the last row of the page (i.e. index `limit - 1`, since we
# over-fetched), and drop the sentinel.
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
refs = refs[:limit]
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
@ -277,7 +345,39 @@ def list_assets_page(
)
)
return ListAssetsResult(items=items, total=total)
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
def _resolve_cursor_value(payload: CursorPayload) -> object:
    """Convert a cursor's string value into the Python type of its sort column.

    ``created_at``/``updated_at`` yield a naive ``datetime``, ``size`` yields
    an ``int``, and any other field (``name``) passes the string through
    unchanged.
    """
    field = payload.sort_field
    if field == "size":
        return decode_cursor_int(payload)
    if field == "created_at" or field == "updated_at":
        # The decoded value is UTC-aware, while the DB stores naive UTC. Drop
        # tzinfo so the bound parameter compares against a
        # `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
        aware = decode_cursor_time(payload)
        return aware.replace(tzinfo=None)
    # name: the cursor value is already the string to compare against.
    return payload.value
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
"""Mint a cursor pointing at *ref* for the given sort dimension.
Returns None when the boundary row carries a NULL sort value (e.g. an asset
record whose size_bytes hasn't been backfilled). Continuing pagination
across a NULL boundary is undefined under keyset ordering better to
truncate cleanly here than to mint a cursor that mis-positions.
"""
if sort == "name":
return encode_cursor("name", ref.name, ref.id, order=order)
if sort == "size":
if ref.asset is None or ref.asset.size_bytes is None:
return None
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
value = ref.created_at if sort == "created_at" else ref.updated_at
if value is None:
return None
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
def resolve_hash_to_path(

View File

@ -0,0 +1,213 @@
"""Opaque keyset-pagination cursor for /api/assets.
Payload JSON uses short keys to keep the encoded length small:
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
The `o` key binds the cursor to the sort direction it was minted under,
so replaying a `desc` cursor against an `asc` request fails with
``INVALID_CURSOR`` rather than silently walking the wrong direction.
`o` is mandatory on every payload — a cursor without it is rejected as
malformed.
Encoding is base64url with no padding. Cursors are opaque tokens: the
payload format is internal to this server, and clients must treat a
cursor as a black box handed back via `next_cursor`. No byte-level
compatibility with any other implementation is required.
Time values are serialized as Unix microseconds (UTC) — microsecond
precision is sufficient to round-trip the timestamps stored by the
database without rounding rows in the same millisecond bucket.
"""
from __future__ import annotations
import base64
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Iterable, Optional
class InvalidCursorError(ValueError):
    """Raised on a malformed, oversized, or unsupported-sort-field cursor.

    Subclasses ``ValueError``, so callers that already catch ``ValueError``
    also catch this. Map to a 400 response with code ``INVALID_CURSOR`` at the
    handler.
    """
# Wire-format length caps. Cursors are user-controlled, so caps protect the
# decode path from oversized allocations and downstream SQL predicates from
# unbounded strings.
#
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
# server then refuses on the next request.
#
# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above
# the largest cursor the per-field caps can produce. Worst case is value + id
# at their caps with every character JSON-escaping to the six-byte `\uXXXX`
# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192
# the encoder can never mint a cursor that exceeds it, so a freshly minted
# cursor always decodes on the next request and there is no user-visible
# "cursor too long" failure.
MAX_ENCODED_CURSOR_LENGTH = 8192
MAX_CURSOR_VALUE_LENGTH = 512
MAX_CURSOR_ID_LENGTH = 128
@dataclass(frozen=True)
class CursorPayload:
    """Immutable, decoded contents of a pagination cursor."""

    # Sort column the cursor was minted for (JSON key "s").
    sort_field: str
    # Sort value of the boundary row, always carried as a string (JSON key "v").
    value: str
    # Id of the boundary row (JSON key "id").
    id: str
    # Sort direction the cursor is bound to (JSON key "o").
    order: str
_VALID_ORDERS = ("asc", "desc")
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
    """Serialize a cursor payload to an unpadded base64url token.

    The sort direction is stored in the payload so that a later request with a
    flipped ``order`` query parameter is rejected with ``INVALID_CURSOR``
    instead of silently walking the wrong direction.

    Raises:
        InvalidCursorError: if ``order`` is not a valid direction, ``id`` is
            empty or too long, or ``value`` is too long.
    """
    if order not in _VALID_ORDERS:
        raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")

    # Keep these checks in step with the decoder: anything it refuses must be
    # refused here too, or the server would hand out cursors that it answers
    # with a 400 on the very next request.
    if not id:
        raise InvalidCursorError("id must be non-empty")
    if len(id) > MAX_CURSOR_ID_LENGTH:
        raise InvalidCursorError("id exceeds maximum length")
    if len(value) > MAX_CURSOR_VALUE_LENGTH:
        raise InvalidCursorError("value exceeds maximum length")

    raw = json.dumps(
        {"s": sort_field, "v": value, "id": id, "o": order},
        separators=(",", ":"),
        ensure_ascii=False,
    )
    # The encoded length is not checked here: the per-field caps above already
    # keep it below MAX_ENCODED_CURSOR_LENGTH.
    token = base64.urlsafe_b64encode(raw.encode("utf-8"))
    return token.rstrip(b"=").decode("ascii")
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
    """Encode a cursor whose value is a timestamp, at Unix-microsecond precision.

    ``t`` must be timezone-aware; it is normalized to UTC before conversion.

    Raises:
        ValueError: if ``t`` is naive, so a caller cannot accidentally encode
            the local wall-clock reading of a UTC-stored timestamp.
    """
    if t.tzinfo is None:
        raise ValueError("encode_cursor_from_time requires an aware datetime")
    utc_time = t.astimezone(timezone.utc)
    value = str(_datetime_to_unix_micros(utc_time))
    return encode_cursor(sort_field, value, id, order=order)
def decode_cursor(
    cursor: str,
    allowed_sort_fields: Iterable[str],
    expected_order: str | None = None,
) -> CursorPayload:
    """Parse and validate an opaque cursor token.

    Args:
        cursor: The base64url (unpadded) token supplied by the client.
        allowed_sort_fields: The endpoint's accepted sort fields. A cursor
            carrying a field outside this set is rejected, so a cursor minted
            for one column can't be replayed against another (e.g. a
            ``created_at`` timestamp string compared against a ``name``
            column). Passing no allowed fields rejects every cursor.
        expected_order: ``"asc"``/``"desc"``. When supplied it must match the
            payload's ``o`` field. ``o`` itself is required on every payload.

    Returns:
        The decoded ``CursorPayload``.

    Raises:
        InvalidCursorError: if the token is too long, is not valid base64url,
            does not contain a JSON object, has missing, mistyped, or
            oversized fields, or disagrees with ``allowed_sort_fields`` or
            ``expected_order``.
    """
    if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
        raise InvalidCursorError("cursor exceeds maximum length")

    try:
        # urlsafe_b64decode requires correct padding; the encoder strips it,
        # so restore the trailing '=' pad here.
        padding = "=" * (-len(cursor) % 4)
        raw = base64.urlsafe_b64decode(cursor + padding)
    except ValueError as e:
        # binascii.Error (bad padding / bad alphabet) subclasses ValueError,
        # so this single clause covers every decode failure.
        raise InvalidCursorError(f"encoding: {e}") from e

    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, UnicodeDecodeError, RecursionError) as e:
        # RecursionError: a crafted payload of deeply nested arrays/objects
        # fits under the length cap and can exhaust the JSON parser's
        # recursion limit. It is not a ValueError, so without this clause it
        # would surface as a 500 instead of INVALID_CURSOR.
        raise InvalidCursorError(f"payload: {e}") from e

    if not isinstance(decoded, dict):
        raise InvalidCursorError("payload: expected object")

    sort_field = decoded.get("s")
    value = decoded.get("v")
    ref_id = decoded.get("id")
    order = decoded.get("o")

    if (
        not isinstance(sort_field, str)
        or not isinstance(value, str)
        or not isinstance(ref_id, str)
    ):
        raise InvalidCursorError("payload: missing or non-string s/v/id")
    if ref_id == "":
        raise InvalidCursorError("missing id")
    if len(ref_id) > MAX_CURSOR_ID_LENGTH:
        raise InvalidCursorError("id exceeds maximum length")
    if len(value) > MAX_CURSOR_VALUE_LENGTH:
        raise InvalidCursorError("value exceeds maximum length")
    if sort_field not in allowed_sort_fields:
        raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
    if not isinstance(order, str):
        raise InvalidCursorError("missing or non-string o")
    if order not in _VALID_ORDERS:
        raise InvalidCursorError(f"unsupported order {order!r}")
    if expected_order is not None and order != expected_order:
        raise InvalidCursorError(
            f"cursor order {order!r} does not match request order {expected_order!r}"
        )

    return CursorPayload(sort_field=sort_field, value=value, id=ref_id, order=order)
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
    """Interpret the cursor value as Unix microseconds and return a UTC datetime.

    Raises:
        InvalidCursorError: if ``payload`` is None, the value is not an
            integer, or the instant cannot be represented as a ``datetime``.
    """
    if payload is None:
        raise InvalidCursorError("nil cursor payload")

    try:
        micros = int(payload.value)
    except ValueError as exc:
        raise InvalidCursorError(f"value is not a valid timestamp: {exc}") from exc

    try:
        moment = _unix_micros_to_datetime(micros)
    except (OverflowError, OSError, ValueError) as exc:
        # A crafted, out-of-range microsecond count fails during datetime
        # construction; report it as a bad cursor (400) rather than a 500.
        raise InvalidCursorError(f"value is out of representable range: {exc}") from exc
    return moment
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
    """Parse a cursor value as a base-10 integer that fits in signed 64 bits.

    Raises:
        InvalidCursorError: if ``payload`` is None, the value is not an
            integer, or the integer lies outside the signed 64-bit range.
    """
    if payload is None:
        raise InvalidCursorError("nil cursor payload")
    try:
        number = int(payload.value)
    except ValueError as e:
        raise InvalidCursorError(f"value is not a valid integer: {e}") from e
    # The cursor is user-controlled and may carry hundreds of digits. A SQLite
    # INTEGER holds at most a signed 64-bit value, so a larger number would
    # fail when bound as a query parameter instead of being reported as an
    # invalid cursor. No legitimately minted cursor can be out of this range.
    if not -(2**63) <= number <= 2**63 - 1:
        raise InvalidCursorError("value is out of range for a 64-bit integer")
    return number
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
def _datetime_to_unix_micros(t: datetime) -> int:
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
delta = t - _EPOCH
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
def _unix_micros_to_datetime(micros: int) -> datetime:
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
seconds, micro_remainder = divmod(micros, 1_000_000)
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)

View File

@ -0,0 +1,63 @@
"""Image dimension extraction for asset ingest.
Reads only the image header via Pillow to capture width/height cheaply,
without a full pixel decode. Returns a metadata dict suitable for merging
into ``AssetReference.system_metadata``.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def extract_image_dimensions(
    file_path: str, mime_type: str | None = None
) -> dict[str, Any] | None:
    """Extract image dimensions for the file at ``file_path``.

    Extraction is best-effort: every failure is logged at debug level and
    reported as ``None`` rather than raised.

    Args:
        file_path: Absolute path to a file on disk.
        mime_type: Optional MIME type hint. When provided and not prefixed
            with ``image/``, extraction is skipped without touching the file.

    Returns:
        ``{"kind": "image", "width": W, "height": H}`` when the file is a
        recognizable image with positive dimensions, otherwise ``None``
        (non-image hint, Pillow missing, unreadable or unidentifiable file,
        image rejected by Pillow's decompression-bomb check, or non-positive
        dimensions).

    The dict shape is intended to be merged into ``system_metadata`` so the
    asset response surfaces ``metadata.kind`` plus dimension fields for image
    assets.
    """
    if mime_type is not None and not mime_type.startswith("image/"):
        return None

    # Pillow is imported lazily so the module still loads without it.
    try:
        from PIL import Image, UnidentifiedImageError
    except ImportError:
        logger.debug(
            "Pillow not available; skipping image dimension extraction for %s",
            file_path,
        )
        return None

    try:
        with Image.open(file_path) as img:
            width, height = img.size
    except (
        OSError,
        UnidentifiedImageError,
        ValueError,
        # Raised by Image.open for images far above Pillow's pixel limit. It
        # derives directly from Exception, so it has to be listed explicitly
        # or it would escape this best-effort helper and fail the caller.
        Image.DecompressionBombError,
    ) as exc:
        logger.debug(
            "Failed to read image dimensions from %s: %s", file_path, exc
        )
        return None

    if (
        not isinstance(width, int)
        or not isinstance(height, int)
        or width <= 0
        or height <= 0
    ):
        return None

    return {"kind": "image", "width": width, "height": height}

View File

@ -17,9 +17,11 @@ from app.assets.database.queries import (
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
list_references_by_asset_id,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_system_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
@ -29,6 +31,7 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
@ -118,6 +121,14 @@ def _ingest_file_from_path(
user_metadata=user_metadata,
)
_maybe_store_image_dimensions(
session,
reference_id=reference_id,
file_path=locator,
mime_type=mime_type,
current_system_metadata=ref.system_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
@ -288,6 +299,13 @@ def _register_existing_asset(
user_metadata=new_meta,
)
_backfill_image_dimensions_from_siblings(
session,
asset_id=asset.id,
new_reference_id=ref.id,
current_system_metadata=ref.system_metadata,
)
if tags is not None:
set_reference_tags(
session,
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
)
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
def _maybe_store_image_dimensions(
session: Session,
reference_id: str,
file_path: str,
mime_type: str | None,
current_system_metadata: dict | None,
) -> None:
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
safetensors metadata, download provenance) are preserved by merge.
"""
if not mime_type or not mime_type.startswith("image/"):
return
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if not dims:
return
current = current_system_metadata or {}
merged = dict(current)
merged.update(dims)
if merged != current:
set_reference_system_metadata(
session,
reference_id=reference_id,
system_metadata=merged,
)
def _backfill_image_dimensions_from_siblings(
session: Session,
asset_id: str,
new_reference_id: str,
current_system_metadata: dict | None,
) -> None:
"""Copy image dimension keys from any sibling reference of the same asset.
The from-hash path doesn't read the file bytes, so dimensions can't be
extracted there directly. When another reference of the same asset already
carries image dimensions, copy them onto the new reference so consumers
see consistent metadata regardless of how the asset was registered.
Best-effort: missing siblings, non-image siblings, or absent dimension
keys leave the target reference unchanged.
"""
current = current_system_metadata or {}
if current.get("kind") == "image" and "width" in current and "height" in current:
return
for sibling in list_references_by_asset_id(session, asset_id):
if sibling.id == new_reference_id:
continue
meta = sibling.system_metadata or {}
if meta.get("kind") != "image":
continue
width = meta.get("width")
height = meta.get("height")
if (
type(width) is not int
or type(height) is not int
or width <= 0
or height <= 0
):
continue
merged = dict(current)
merged["kind"] = "image"
merged["width"] = width
merged["height"] = height
if merged != current:
set_reference_system_metadata(
session,
reference_id=new_reference_id,
system_metadata=merged,
)
return
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback

View File

@ -56,7 +56,6 @@ class IngestResult:
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@ -71,6 +70,7 @@ class AssetSummaryData:
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
next_cursor: str | None = None
@dataclass(frozen=True)

View File

@ -75,7 +75,7 @@ def list_tags(
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
return [TagUsage(name, count) for name, count in rows], total
def list_tag_histogram(

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,569 @@
{
"revision": 0,
"last_node_id": 89,
"last_link_id": 0,
"nodes": [
{
"id": 89,
"type": "85e595bd-af9e-40ee-85c5-b98bb15da47a",
"pos": [
320,
520
],
"size": [
400,
360
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": null
},
{
"name": "resolution",
"type": "INT",
"widget": {
"name": "resolution"
},
"link": null
},
{
"name": "resize_method",
"type": "COMBO",
"widget": {
"name": "resize_method"
},
"link": null
},
{
"label": "output_type",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "output"
},
"link": null
},
{
"label": "output_normalization",
"name": "output.normalization",
"type": "COMBO",
"widget": {
"name": "output.normalization"
},
"link": null
},
{
"label": "apply_sky_clip",
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"widget": {
"name": "output.apply_sky_clip"
},
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"87",
"resolution"
],
[
"87",
"resize_method"
],
[
"86",
"output"
],
[
"86",
"output.normalization"
],
[
"86",
"output.apply_sky_clip"
],
[
"88",
"model_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.24.0"
},
"widgets_values": [],
"title": "Image Depth Estimation (Depth Anything 3)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "85e595bd-af9e-40ee-85c5-b98bb15da47a",
"version": 1,
"state": {
"lastGroupId": 4,
"lastNodeId": 89,
"lastLinkId": 109,
"lastRerouteId": 0
},
"revision": 2,
"config": {},
"name": "Image Depth Estimation (Depth Anything 3)",
"inputNode": {
"id": -10,
"bounding": [
400,
90,
166.998046875,
188
]
},
"outputNode": {
"id": -20,
"bounding": [
1250,
146,
128,
68
]
},
"inputs": [
{
"id": "43cf3118-495a-487d-8eb3-a17c7e92f64f",
"name": "image",
"type": "IMAGE",
"linkIds": [
19
],
"localized_name": "image",
"pos": [
542.998046875,
114
]
},
{
"id": "1089a0a1-6db1-45a8-84b0-0bfdc2ed920a",
"name": "resolution",
"type": "INT",
"linkIds": [
22
],
"pos": [
542.998046875,
134
]
},
{
"id": "25fb64ac-26d5-466d-995b-6d51b9afa2c4",
"name": "resize_method",
"type": "COMBO",
"linkIds": [
23
],
"pos": [
542.998046875,
154
]
},
{
"id": "8acafb7c-6c8b-46b3-9d74-c563498a3af1",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"linkIds": [
24
],
"label": "output_type",
"pos": [
542.998046875,
174
]
},
{
"id": "1da5009b-4648-43e8-a257-16426630cf22",
"name": "output.normalization",
"type": "COMBO",
"linkIds": [
25
],
"label": "output_normalization",
"pos": [
542.998046875,
194
]
},
{
"id": "fd7edb33-5fb1-4538-a411-26e5039a9321",
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"linkIds": [
26
],
"label": "apply_sky_clip",
"pos": [
542.998046875,
214
]
},
{
"id": "b5be4c8a-b833-4f1e-8c94-3ed1dd722190",
"name": "model_name",
"type": "COMBO",
"linkIds": [
106
],
"pos": [
542.998046875,
234
]
}
],
"outputs": [
{
"id": "478ab537-63bc-4d74-a9f0-c975f550880f",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
7
],
"localized_name": "IMAGE",
"pos": [
1274,
170
]
}
],
"widgets": [],
"nodes": [
{
"id": 86,
"type": "DA3Render",
"pos": [
800,
310
],
"size": [
380,
130
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "da3_geometry",
"name": "da3_geometry",
"type": "DA3_GEOMETRY",
"link": 12
},
{
"localized_name": "output",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "output"
},
"link": 24
},
{
"localized_name": "output.normalization",
"name": "output.normalization",
"type": "COMBO",
"widget": {
"name": "output.normalization"
},
"link": 25
},
{
"localized_name": "output.apply_sky_clip",
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"widget": {
"name": "output.apply_sky_clip"
},
"link": 26
},
{
"name": "geometry",
"type": "DA3_GEOMETRY",
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
7
]
}
],
"properties": {
"Node name for S&R": "DA3Render",
"cnr_id": "comfy-core",
"ver": "0.19.0"
},
"widgets_values": [
"depth",
"v2_style",
false
]
},
{
"id": 87,
"type": "DA3Inference",
"pos": [
800,
50
],
"size": [
390,
130
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "da3_model",
"name": "da3_model",
"type": "DA3_MODEL",
"link": 107
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 19
},
{
"localized_name": "resolution",
"name": "resolution",
"type": "INT",
"widget": {
"name": "resolution"
},
"link": 22
},
{
"localized_name": "resize_method",
"name": "resize_method",
"type": "COMBO",
"widget": {
"name": "resize_method"
},
"link": 23
},
{
"localized_name": "mode",
"name": "mode",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "mode"
},
"link": null
}
],
"outputs": [
{
"localized_name": "da3_geometry",
"name": "da3_geometry",
"type": "DA3_GEOMETRY",
"slot_index": 0,
"links": [
12
]
}
],
"properties": {
"Node name for S&R": "DA3Inference",
"cnr_id": "comfy-core",
"ver": "0.19.0"
},
"widgets_values": [
504,
"upper_bound_resize",
"mono"
]
},
{
"id": 88,
"type": "LoadDA3Model",
"pos": [
810,
-160
],
"size": [
400,
140
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 106
},
{
"localized_name": "weight_dtype",
"name": "weight_dtype",
"type": "COMBO",
"widget": {
"name": "weight_dtype"
},
"link": null
}
],
"outputs": [
{
"localized_name": "DA3_MODEL",
"name": "DA3_MODEL",
"type": "DA3_MODEL",
"links": [
107
]
}
],
"properties": {
"Node name for S&R": "LoadDA3Model",
"cnr_id": "comfy-core",
"ver": "0.24.0",
"models": [
{
"name": "depth_anything_3_mono_large.safetensors",
"url": "https://huggingface.co/Comfy-Org/Depth-Anything-3/resolve/main/geometry_estimation/depth_anything_3_mono_large.safetensors",
"directory": "geometry_estimation"
}
]
},
"widgets_values": [
"depth_anything_3_mono_large.safetensors",
"default"
]
}
],
"groups": [],
"links": [
{
"id": 12,
"origin_id": 87,
"origin_slot": 0,
"target_id": 86,
"target_slot": 0,
"type": "DA3_GEOMETRY"
},
{
"id": 19,
"origin_id": -10,
"origin_slot": 0,
"target_id": 87,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 7,
"origin_id": 86,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 22,
"origin_id": -10,
"origin_slot": 1,
"target_id": 87,
"target_slot": 2,
"type": "INT"
},
{
"id": 23,
"origin_id": -10,
"origin_slot": 2,
"target_id": 87,
"target_slot": 3,
"type": "COMBO"
},
{
"id": 24,
"origin_id": -10,
"origin_slot": 3,
"target_id": 86,
"target_slot": 1,
"type": "COMFY_DYNAMICCOMBO_V3"
},
{
"id": 25,
"origin_id": -10,
"origin_slot": 4,
"target_id": 86,
"target_slot": 2,
"type": "COMBO"
},
{
"id": 26,
"origin_id": -10,
"origin_slot": 5,
"target_id": 86,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 106,
"origin_id": -10,
"origin_slot": 6,
"target_id": 88,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 107,
"origin_id": 88,
"origin_slot": 0,
"target_id": 87,
"target_slot": 0,
"type": "DA3_MODEL"
}
],
"extra": {},
"category": "Conditioning & Preprocessors/Depth",
"description": "This subgraph takes an input image and produces a depth map using the Depth Anything 3 model, which recovers spatially consistent geometry from any number of views. It is ideal for single or multi-view images, videos, and 3D scenes where accurate depth estimation is needed for tasks like SLAM, novel view synthesis, or spatial perception. The model uses a plain transformer backbone and supports both monocular and multi-view inputs without requiring known camera poses."
}
]
},
"extra": {
"BlueprintDescription": "This subgraph takes an input image and produces a depth map using the Depth Anything 3 model, which recovers spatially consistent geometry from any number of views. It is ideal for single or multi-view images, videos, and 3D scenes where accurate depth estimation is needed for tasks like SLAM, novel view synthesis, or spatial perception. The model uses a plain transformer backbone and supports both monocular and multi-view inputs without requiring known camera poses."
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1077,9 +1077,12 @@
}
],
"extra": {},
"category": "Image generation and editing/Text to image"
"category": "Image generation and editing/Text to image",
"description": "This subgraph converts text prompts into non-photorealistic illustrations using a 2-billion-parameter model optimized for anime and artistic styles. It is ideal for generating concept art, character designs, or stylized illustrations where photorealism is not required. The model excels with anime and artistic content but performs poorly on realistic subjects."
}
]
},
"extra": {}
"extra": {
"BlueprintDescription": "This subgraph converts text prompts into non-photorealistic illustrations using a 2-billion-parameter model optimized for anime and artistic styles. It is ideal for generating concept art, character designs, or stylized illustrations where photorealism is not required. The model excels with anime and artistic content but performs poorly on realistic subjects."
}
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,825 @@
{
"revision": 0,
"last_node_id": 97,
"last_link_id": 0,
"nodes": [
{
"id": 97,
"type": "253ec5ca-8333-4ddf-a036-9fc0923651b9",
"pos": [
410,
500
],
"size": [
400,
400
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "video",
"type": "VIDEO",
"link": null
},
{
"name": "start_time",
"type": "FLOAT",
"widget": {
"name": "start_time"
},
"link": null
},
{
"name": "duration",
"type": "FLOAT",
"widget": {
"name": "duration"
},
"link": null
},
{
"name": "resolution",
"type": "INT",
"widget": {
"name": "resolution"
},
"link": null
},
{
"name": "resize_method",
"type": "COMBO",
"widget": {
"name": "resize_method"
},
"link": null
},
{
"label": "output_type",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "output"
},
"link": null
},
{
"label": "normalization",
"name": "output.normalization",
"type": "COMBO",
"widget": {
"name": "output.normalization"
},
"link": null
},
{
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"widget": {
"name": "output.apply_sky_clip"
},
"link": null
},
{
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"links": []
},
{
"name": "audio",
"type": "AUDIO",
"links": []
},
{
"name": "fps",
"type": "FLOAT",
"links": []
}
],
"properties": {
"proxyWidgets": [
[
"96",
"start_time"
],
[
"96",
"duration"
],
[
"93",
"resolution"
],
[
"93",
"resize_method"
],
[
"92",
"output"
],
[
"92",
"output.normalization"
],
[
"92",
"output.apply_sky_clip"
],
[
"94",
"model_name"
]
],
"cnr_id": "comfy-core",
"ver": "0.24.0"
},
"widgets_values": [],
"title": "Video Depth Estimation (Depth Anything 3)"
}
],
"links": [],
"version": 0.4,
"definitions": {
"subgraphs": [
{
"id": "253ec5ca-8333-4ddf-a036-9fc0923651b9",
"version": 1,
"state": {
"lastGroupId": 4,
"lastNodeId": 97,
"lastLinkId": 129,
"lastRerouteId": 0
},
"revision": 2,
"config": {},
"name": "Video Depth Estimation (Depth Anything 3)",
"inputNode": {
"id": -10,
"bounding": [
-230,
130,
167.912109375,
228
]
},
"outputNode": {
"id": -20,
"bounding": [
1520,
140,
128,
108
]
},
"inputs": [
{
"id": "698c28c6-cf92-4039-8b39-f3062868ea7c",
"name": "video",
"type": "VIDEO",
"linkIds": [
119
],
"pos": [
-86.087890625,
154
]
},
{
"id": "97a1f63e-1585-4a40-9dec-e2700120d84a",
"name": "start_time",
"type": "FLOAT",
"linkIds": [
121
],
"pos": [
-86.087890625,
174
]
},
{
"id": "4dbbd3b3-c5ee-4a56-a0d3-3268d3b2fd64",
"name": "duration",
"type": "FLOAT",
"linkIds": [
122
],
"pos": [
-86.087890625,
194
]
},
{
"id": "16f55101-f99d-4c0c-bebf-c3b31c54f13e",
"name": "resolution",
"type": "INT",
"linkIds": [
124
],
"pos": [
-86.087890625,
214
]
},
{
"id": "d9cd7693-4bb3-4ed7-9a75-276b997abcd9",
"name": "resize_method",
"type": "COMBO",
"linkIds": [
125
],
"pos": [
-86.087890625,
234
]
},
{
"id": "a6e90532-323b-462e-ba9c-1672384d5b31",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"linkIds": [
126
],
"label": "output_type",
"pos": [
-86.087890625,
254
]
},
{
"id": "69e6aeef-437d-4fde-b2fc-d5ab9369238d",
"name": "output.normalization",
"type": "COMBO",
"linkIds": [
127
],
"label": "normalization",
"pos": [
-86.087890625,
274
]
},
{
"id": "73206f72-f89a-4698-885e-5d9277df2998",
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"linkIds": [
128
],
"pos": [
-86.087890625,
294
]
},
{
"id": "dddbc7fc-9431-448a-9ed3-9aa62404288b",
"name": "model_name",
"type": "COMBO",
"linkIds": [
129
],
"pos": [
-86.087890625,
314
]
}
],
"outputs": [
{
"id": "478ab537-63bc-4d74-a9f0-c975f550880f",
"name": "IMAGE",
"type": "IMAGE",
"linkIds": [
7
],
"localized_name": "IMAGE",
"pos": [
1544,
164
]
},
{
"id": "cdaf037e-79bc-4a94-b06c-0fd32e76f615",
"name": "audio",
"type": "AUDIO",
"linkIds": [
112
],
"pos": [
1544,
184
]
},
{
"id": "4c0e5484-d193-49c7-b107-92619628880a",
"name": "fps",
"type": "FLOAT",
"linkIds": [
113
],
"pos": [
1544,
204
]
}
],
"widgets": [],
"nodes": [
{
"id": 92,
"type": "DA3Render",
"pos": [
740,
230
],
"size": [
380,
130
],
"flags": {},
"order": 0,
"mode": 0,
"inputs": [
{
"localized_name": "da3_geometry",
"name": "da3_geometry",
"type": "DA3_GEOMETRY",
"link": 12
},
{
"localized_name": "output",
"name": "output",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "output"
},
"link": 126
},
{
"localized_name": "output.normalization",
"name": "output.normalization",
"type": "COMBO",
"widget": {
"name": "output.normalization"
},
"link": 127
},
{
"localized_name": "output.apply_sky_clip",
"name": "output.apply_sky_clip",
"type": "BOOLEAN",
"widget": {
"name": "output.apply_sky_clip"
},
"link": 128
},
{
"name": "geometry",
"type": "DA3_GEOMETRY",
"link": null
}
],
"outputs": [
{
"localized_name": "IMAGE",
"name": "IMAGE",
"type": "IMAGE",
"slot_index": 0,
"links": [
7
]
}
],
"properties": {
"Node name for S&R": "DA3Render",
"cnr_id": "comfy-core",
"ver": "0.19.0"
},
"widgets_values": [
"depth",
"v2_style",
false
]
},
{
"id": 93,
"type": "DA3Inference",
"pos": [
740,
-30
],
"size": [
390,
130
],
"flags": {},
"order": 1,
"mode": 0,
"inputs": [
{
"localized_name": "da3_model",
"name": "da3_model",
"type": "DA3_MODEL",
"link": 107
},
{
"localized_name": "image",
"name": "image",
"type": "IMAGE",
"link": 111
},
{
"localized_name": "resolution",
"name": "resolution",
"type": "INT",
"widget": {
"name": "resolution"
},
"link": 124
},
{
"localized_name": "resize_method",
"name": "resize_method",
"type": "COMBO",
"widget": {
"name": "resize_method"
},
"link": 125
},
{
"localized_name": "mode",
"name": "mode",
"type": "COMFY_DYNAMICCOMBO_V3",
"widget": {
"name": "mode"
},
"link": null
}
],
"outputs": [
{
"localized_name": "da3_geometry",
"name": "da3_geometry",
"type": "DA3_GEOMETRY",
"slot_index": 0,
"links": [
12
]
}
],
"properties": {
"Node name for S&R": "DA3Inference",
"cnr_id": "comfy-core",
"ver": "0.19.0"
},
"widgets_values": [
504,
"lower_bound_resize",
"mono"
]
},
{
"id": 94,
"type": "LoadDA3Model",
"pos": [
50,
410
],
"size": [
400,
140
],
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"localized_name": "model_name",
"name": "model_name",
"type": "COMBO",
"widget": {
"name": "model_name"
},
"link": 129
},
{
"localized_name": "weight_dtype",
"name": "weight_dtype",
"type": "COMBO",
"widget": {
"name": "weight_dtype"
},
"link": null
}
],
"outputs": [
{
"localized_name": "DA3_MODEL",
"name": "DA3_MODEL",
"type": "DA3_MODEL",
"links": [
107
]
}
],
"properties": {
"Node name for S&R": "LoadDA3Model",
"cnr_id": "comfy-core",
"ver": "0.24.0",
"models": [
{
"name": "depth_anything_3_mono_large.safetensors",
"url": "https://huggingface.co/Comfy-Org/Depth-Anything-3/resolve/main/geometry_estimation/depth_anything_3_mono_large.safetensors",
"directory": "geometry_estimation"
}
]
},
"widgets_values": [
"depth_anything_3_mono_large.safetensors",
"default"
]
},
{
"id": 95,
"type": "GetVideoComponents",
"pos": [
70,
-140
],
"size": [
260,
120
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 120
}
],
"outputs": [
{
"localized_name": "images",
"name": "images",
"type": "IMAGE",
"links": [
111
]
},
{
"localized_name": "audio",
"name": "audio",
"type": "AUDIO",
"links": [
112
]
},
{
"localized_name": "fps",
"name": "fps",
"type": "FLOAT",
"links": [
113
]
},
{
"localized_name": "bit_depth",
"name": "bit_depth",
"type": "INT",
"links": null
}
],
"properties": {
"Node name for S&R": "GetVideoComponents",
"cnr_id": "comfy-core",
"ver": "0.24.0"
}
},
{
"id": 96,
"type": "Video Slice",
"pos": [
70,
-360
],
"size": [
270,
170
],
"flags": {},
"order": 4,
"mode": 0,
"inputs": [
{
"localized_name": "video",
"name": "video",
"type": "VIDEO",
"link": 119
},
{
"localized_name": "start_time",
"name": "start_time",
"type": "FLOAT",
"widget": {
"name": "start_time"
},
"link": 121
},
{
"localized_name": "duration",
"name": "duration",
"type": "FLOAT",
"widget": {
"name": "duration"
},
"link": 122
},
{
"localized_name": "strict_duration",
"name": "strict_duration",
"type": "BOOLEAN",
"widget": {
"name": "strict_duration"
},
"link": null
}
],
"outputs": [
{
"localized_name": "VIDEO",
"name": "VIDEO",
"type": "VIDEO",
"links": [
120
]
}
],
"properties": {
"Node name for S&R": "Video Slice",
"cnr_id": "comfy-core",
"ver": "0.24.0"
},
"widgets_values": [
0,
5,
false
]
}
],
"groups": [],
"links": [
{
"id": 12,
"origin_id": 93,
"origin_slot": 0,
"target_id": 92,
"target_slot": 0,
"type": "DA3_GEOMETRY"
},
{
"id": 7,
"origin_id": 92,
"origin_slot": 0,
"target_id": -20,
"target_slot": 0,
"type": "IMAGE"
},
{
"id": 107,
"origin_id": 94,
"origin_slot": 0,
"target_id": 93,
"target_slot": 0,
"type": "DA3_MODEL"
},
{
"id": 111,
"origin_id": 95,
"origin_slot": 0,
"target_id": 93,
"target_slot": 1,
"type": "IMAGE"
},
{
"id": 112,
"origin_id": 95,
"origin_slot": 1,
"target_id": -20,
"target_slot": 1,
"type": "AUDIO"
},
{
"id": 113,
"origin_id": 95,
"origin_slot": 2,
"target_id": -20,
"target_slot": 2,
"type": "FLOAT"
},
{
"id": 119,
"origin_id": -10,
"origin_slot": 0,
"target_id": 96,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 120,
"origin_id": 96,
"origin_slot": 0,
"target_id": 95,
"target_slot": 0,
"type": "VIDEO"
},
{
"id": 121,
"origin_id": -10,
"origin_slot": 1,
"target_id": 96,
"target_slot": 1,
"type": "FLOAT"
},
{
"id": 122,
"origin_id": -10,
"origin_slot": 2,
"target_id": 96,
"target_slot": 2,
"type": "FLOAT"
},
{
"id": 124,
"origin_id": -10,
"origin_slot": 3,
"target_id": 93,
"target_slot": 2,
"type": "INT"
},
{
"id": 125,
"origin_id": -10,
"origin_slot": 4,
"target_id": 93,
"target_slot": 3,
"type": "COMBO"
},
{
"id": 126,
"origin_id": -10,
"origin_slot": 5,
"target_id": 92,
"target_slot": 1,
"type": "COMFY_DYNAMICCOMBO_V3"
},
{
"id": 127,
"origin_id": -10,
"origin_slot": 6,
"target_id": 92,
"target_slot": 2,
"type": "COMBO"
},
{
"id": 128,
"origin_id": -10,
"origin_slot": 7,
"target_id": 92,
"target_slot": 3,
"type": "BOOLEAN"
},
{
"id": 129,
"origin_id": -10,
"origin_slot": 8,
"target_id": 94,
"target_slot": 0,
"type": "COMBO"
}
],
"extra": {},
"category": "Conditioning & Preprocessors/Depth",
"description": "This subgraph processes a video input through Depth Anything 3 to produce temporally consistent depth maps for each frame, outputting a depth video. It is ideal for video content requiring spatial geometry estimation, such as 3D reconstruction, SLAM, or novel view synthesis from moving cameras. The model uses a plain transformer backbone trained with a depth-ray representation, supporting any number of views without requiring known camera poses."
}
]
},
"extra": {
"BlueprintDescription": "This subgraph processes a video input through Depth Anything 3 to produce temporally consistent depth maps for each frame, outputting a depth video. It is ideal for video content requiring spatial geometry estimation, such as 3D reconstruction, SLAM, or novel view synthesis from moving cameras. The model uses a plain transformer backbone trained with a depth-ray representation, supporting any number of views without requiring known camera poses."
}
}

File diff suppressed because it is too large Load Diff

View File

@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
vram_group = parser.add_mutually_exclusive_group()
@ -144,6 +145,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--vram-headroom", type=float, default=0, help="Set the amount of vram in GB for DynamicVRAM to maintain as extra headroom above default. ComfyUI will try and keep this much VRAM completely free and unused, even counting VRAM from other apps.")
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
@ -166,6 +168,8 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
@ -247,6 +251,9 @@ else:
if args.cache_ram is not None and len(args.cache_ram) > 2:
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
if args.high_ram:
args.cache_classic = True
if args.windows_standalone_build:
args.auto_launch = True
@ -256,6 +263,10 @@ if args.disable_auto_launch:
if args.force_fp16:
args.fp16_unet = True
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
if args.enable_manager_legacy_ui:
args.enable_manager = True
# '--fast' is not provided, use an empty set
if args.fast is None:

View File

@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
import logging
import comfy.model_management
import comfy.patcher_extension
import comfy.utils
import comfy.conds
if TYPE_CHECKING:
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
class IndexListContextWindow(ContextWindowABC):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
self.index_list = index_list
self.context_length = len(index_list)
self.context_overlap = context_overlap
self.dim = dim
self.total_frames = total_frames
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
self.guide_frames_indices: list[int] = []
self.guide_overlap_info: list[tuple[int, int]] = []
self.guide_kf_local_positions: list[int] = []
self.guide_downscale_factors: list[int] = []
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
if dim is None:
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
region_idx = int(self.center_ratio * num_regions)
return min(max(region_idx, 0), num_regions - 1)
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
    """Return the context window to use for the given modality index.

    Index 0 is the primary modality and resolves to this window itself; any
    other index is looked up in ``self.modality_windows``.
    """
    is_primary = modality_idx == 0
    return self if is_primary else self.modality_windows[modality_idx]
class IndexListCallbacks:
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
return cond_value._copy_with(sliced)
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
    """Compute which concatenated guide frames overlap with a context window.

    Each guide's latent-space start is derived from its first token's pixel-t-start
    in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
    model's temporal_downscale_ratio (rounded up).

    Args:
        guide_entries: list of guide_attention_entry dicts; each must provide
            "latent_shape" (its first element is the guide's frame count) and
            "pre_filter_count" (how many tokens the guide occupies in keyframe_idxs)
        keyframe_idxs: per-token pixel coords cond tensor for the modality
        temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
        window_index_list: the window's frame indices into the video portion

    Returns:
        suffix_indices: indices into the guide_frames tensor for frame selection
        overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
        kf_local_positions: window-local frame positions for keyframe_idxs regeneration
        total_overlap: total number of overlapping guide frames
    """
    # Map each frame index to its first position inside the window. A single
    # dict replaces the former set-membership test plus list.index() lookup,
    # which rescanned the window for every overlapping guide frame.
    # setdefault keeps the first occurrence, matching list.index() semantics.
    window_positions: dict[int, int] = {}
    for window_pos, frame_idx in enumerate(window_index_list):
        window_positions.setdefault(frame_idx, window_pos)

    suffix_indices = []
    overlap_info = []
    kf_local_positions = []
    suffix_base = 0  # offset of the current guide inside the concatenated guide frames
    token_offset = 0  # offset of the current guide's first token in keyframe_idxs

    for entry_idx, entry in enumerate(guide_entries):
        first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
        # Ceiling division: pixel-space start -> latent-space start.
        latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
        guide_len = entry["latent_shape"][0]

        entry_overlap = 0
        for local_offset in range(guide_len):
            window_pos = window_positions.get(latent_start + local_offset)
            if window_pos is None:
                continue
            suffix_indices.append(suffix_base + local_offset)
            kf_local_positions.append(window_pos)
            entry_overlap += 1

        if entry_overlap > 0:
            overlap_info.append((entry_idx, entry_overlap))

        suffix_base += guide_len
        token_offset += entry["pre_filter_count"]

    return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
@dataclass
class WindowingState:
"""Per-modality context windowing state for each step,
built using IndexListContextHandler._build_window_state().
For non-multimodal models the lists are length 1
"""
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
dim: int = 0 # primary modality temporal dim for context windowing
is_multimodal: bool = False
temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
    """Reformat window for multimodal contexts by deriving per-modality index lists.

    Non-multimodal contexts return the input window unchanged.

    For multimodal contexts the primary window's index list is handed to
    ``model.map_context_window_to_modalities`` together with the packed latent
    shapes, and one ``IndexListContextWindow`` is built per secondary modality.
    Each secondary window's overlap is the primary overlap scaled by the ratio
    of that modality's frame count to the primary frame count.

    Args:
        window: Context window over the primary modality.
        model: Object expected to provide ``map_context_window_to_modalities``.

    Returns:
        ``window`` itself when not multimodal, otherwise a new window over the
        same indices that carries the per-modality windows.

    Raises:
        NotImplementedError: If the context is multimodal and ``model`` does
            not provide ``map_context_window_to_modalities``.
    """
    if not self.is_multimodal:
        return window

    x = self.latents[0]
    primary_total = self.latent_shapes[0][self.dim]
    primary_overlap = window.context_overlap

    # If the working primary latent no longer has the packed length along the
    # windowing dim, pass the model shapes that reflect its current length.
    map_shapes = self.latent_shapes
    if x.size(self.dim) != primary_total:
        map_shapes = list(self.latent_shapes)
        video_shape = list(self.latent_shapes[0])
        video_shape[self.dim] = x.size(self.dim)
        map_shapes[0] = torch.Size(video_shape)

    # Look the method up explicitly instead of catching AttributeError around
    # the call: an AttributeError raised *inside* the model's implementation
    # must surface as-is rather than be misreported as "not implemented".
    map_to_modalities = getattr(model, "map_context_window_to_modalities", None)
    if map_to_modalities is None:
        raise NotImplementedError(
            f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
    per_modality_indices = map_to_modalities(window.index_list, map_shapes, self.dim)

    modality_windows = {}
    for mod_idx in range(1, len(self.latents)):
        modality_total_frames = self.latents[mod_idx].shape[self.dim]
        ratio = modality_total_frames / primary_total if primary_total > 0 else 1
        modality_overlap = max(round(primary_overlap * ratio), 0)
        modality_windows[mod_idx] = IndexListContextWindow(
            per_modality_indices[mod_idx], dim=self.dim,
            total_frames=modality_total_frames,
            context_overlap=modality_overlap)

    return IndexListContextWindow(
        window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
        modality_windows=modality_windows, context_overlap=primary_overlap)
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
    """Slice every modality's latent for one context window.

    Each modality is sliced with the window returned by
    ``window.get_window_for_modality``; modalities that have guide entries
    additionally go through ``_inject_guide_frames``.

    Returns:
        ``(slices, guide_counts)`` - one sliced tensor and one injected-guide
        count per modality, in modality order.
    """
    window_slices = []
    injected_counts = []
    for mod_idx, latent in enumerate(self.latents):
        mod_window = window.get_window_for_modality(mod_idx)
        # retain_index_list is only forwarded for modality 0; others get an empty list
        kept = [] if mod_idx != 0 else retain_index_list
        piece = mod_window.get_tensor(latent, device, retain_index_list=kept)
        injected = 0
        if self.guide_entries[mod_idx] is not None:
            piece, injected = self._inject_guide_frames(piece, mod_window, modality_idx=mod_idx)
        window_slices.append(piece)
        injected_counts.append(injected)
    return window_slices, injected_counts
def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
    """Drop injected guide frames from per-cond, per-modality outputs, in place.

    For every modality whose guide count is positive, each cond's output is
    narrowed along ``self.dim`` to the first ``len(index_list)`` entries of
    that modality's window. ``out_per_modality`` is indexed ``[cond][modality]``.
    """
    for mod_idx in range(len(self.latents)):
        if guide_frame_counts[mod_idx] <= 0:
            continue
        keep = len(window.get_window_for_modality(mod_idx).index_list)
        for cond_outputs in out_per_modality:
            cond_outputs[mod_idx] = cond_outputs[mod_idx].narrow(self.dim, 0, keep)
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
    """Append the guide frames selected for ``window`` to ``latent_slice``.

    Side effects: sets ``guide_frames_indices``, ``guide_overlap_info``,
    ``guide_kf_local_positions`` and ``guide_downscale_factors`` on ``window``.

    Returns:
        ``(tensor, guide_frame_count)``; the tensor is ``latent_slice`` itself
        when no guide frames were selected.
    """
    entries = self.guide_entries[modality_idx]
    guides = self.guide_latents[modality_idx]
    kf_idxs = self.keyframe_idxs[modality_idx]
    suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
        entries, kf_idxs, self.temporal_downscale_ratio, window.index_list)

    # Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
    anchor_idx = getattr(window, 'causal_anchor_index', None)
    if not (anchor_idx is None or anchor_idx < 0):
        kf_local_pos = [pos + 1 for pos in kf_local_pos]

    window.guide_frames_indices = suffix_idx
    window.guide_overlap_info = overlap_info
    window.guide_kf_local_positions = kf_local_pos

    has_guides = guide_frame_count > 0

    # One downscale factor per overlap entry: guide-frame size along axis 3
    # floor-divided by the entry's latent_shape[1]. The original comment describes
    # these as post-dilation vs pre-dilation spatial dims.
    downscale_factors = []
    if has_guides:
        full_height = guides.shape[3]
        downscale_factors = [
            full_height // entries[entry_idx]["latent_shape"][1]
            for entry_idx, _ in overlap_info
        ]
    window.guide_downscale_factors = downscale_factors

    if not has_guides:
        return latent_slice, 0

    # full slices on every axis before self.dim, then the selected guide indices on self.dim
    selector = (slice(None),) * self.dim + (suffix_idx,)
    return torch.cat([latent_slice, guides[selector]], dim=self.dim), guide_frame_count
def patch_latent_shapes(self, sub_conds, new_shapes):
    """Replace the ``latent_shapes`` model cond in every cond dict with ``new_shapes``.

    Does nothing for non-multimodal state. ``None`` cond lists are skipped and
    cond dicts without a ``latent_shapes`` entry are left untouched.
    """
    if self.is_multimodal:
        present_lists = (cond_list for cond_list in sub_conds if cond_list is not None)
        for cond_list in present_lists:
            for cond_dict in cond_list:
                model_conds = cond_dict.get('model_conds', {})
                if 'latent_shapes' not in model_conds:
                    continue
                model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
@dataclass
class ContextSchedule:
name: str
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
causal_window_fix: bool=True):
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
self.causal_window_fix = causal_window_fix
self.callbacks = {}
@staticmethod
def _get_latent_shapes(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
if 'latent_shapes' in model_conds:
return model_conds['latent_shapes'].cond
return None
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
@staticmethod
def _get_keyframe_idxs(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
kf = model_conds.get('keyframe_idxs')
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
return kf.cond
return None
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
    """Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.

    If guide frames are present on the primary modality, only the leading
    (non-guide) portion along ``self.dim`` is passed to ``apply_freenoise``.

    Returns:
        ``noise`` itself for the single-modality case, or the re-packed tensor
        for the multimodal case.
    """
    entries = self._get_guide_entries(conds)
    guide_count = 0
    if entries:
        guide_count = sum(entry["latent_shape"][0] for entry in entries)

    shapes = self._get_latent_shapes(conds)
    if shapes is None or len(shapes) <= 1:
        # The return value is ignored here; presumably apply_freenoise shuffles
        # the narrowed view in place - verify against its definition.
        video_count = noise.size(self.dim) - guide_count
        apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
        return noise

    modalities = comfy.utils.unpack_latents(noise, shapes)
    primary_total = shapes[0][self.dim]
    primary_video_count = modalities[0].size(self.dim) - guide_count
    # Same in-place assumption as above for the primary modality.
    apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
    for i in range(1, len(modalities)):
        mod_total = shapes[i][self.dim]
        # scale the window geometry by this modality's frame count relative to the primary
        ratio = mod_total / primary_total if primary_total > 0 else 1
        mod_ctx_len = max(round(self.context_length * ratio), 1)
        mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
        modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
    noise, _ = comfy.utils.pack_latents(modalities)
    return noise
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
    """Build the windowing state for the current step.

    Unpacks ``x_in`` into per-modality latents when the conds carry more than
    one latent shape, and splits trailing guide frames off the first modality
    when guide entries are present in the conds.
    """
    shapes = self._get_latent_shapes(conds)
    multimodal = shapes is not None and len(shapes) > 1
    if multimodal:
        unpacked = comfy.utils.unpack_latents(x_in, shapes)
    else:
        unpacked = [x_in]

    modality_count = len(unpacked)
    latents = list(unpacked)
    guide_latents = [None] * modality_count
    guide_entries = [None] * modality_count
    keyframe_idxs = [None] * modality_count

    found_entries = self._get_guide_entries(conds)
    found_keyframes = self._get_keyframe_idxs(conds)

    # Strip guide frames (only from first modality for now)
    guide_count = 0
    if found_entries is not None:
        guide_count = sum(entry["latent_shape"][0] for entry in found_entries)
    if guide_count > 0:
        primary = unpacked[0]
        latent_count = primary.size(self.dim) - guide_count
        # leading part = real latents, trailing guide_count entries = guide frames
        latents[0] = primary.narrow(self.dim, 0, latent_count)
        guide_latents[0] = primary.narrow(self.dim, latent_count, guide_count)
        guide_entries[0] = found_entries
        keyframe_idxs[0] = found_keyframes

    return WindowingState(
        latents=latents,
        guide_latents=guide_latents,
        guide_entries=guide_entries,
        keyframe_idxs=keyframe_idxs,
        latent_shapes=shapes,
        dim=self.dim,
        is_multimodal=multimodal,
        temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
if x_in.size(self.dim) > self.context_length:
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
total_frame_count = window_state.latents[0].size(self.dim)
if total_frame_count > self.context_length:
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
if self.cond_retain_index_list:
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
if self.latent_retain_index_list:
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
return True
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
return False
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
return resized_cond
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
current_timestep = timestep[0].to(sample_sigmas.dtype)
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
matches = torch.nonzero(mask)
if torch.numel(matches) == 0:
return # substep from multi-step sampler: keep self._step from the last full step
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
full_length = x_in.size(self.dim) # TODO: choose dim based on model
context_windows = self.context_schedule.func(full_length, self, model_options)
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows]
return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows))
conds_final = [torch.zeros_like(x_in) for _ in conds]
window_state = self._build_window_state(x_in, conds, model)
num_modalities = len(window_state.latents)
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
enumerated_context_windows = list(enumerate(context_windows))
total_windows = len(enumerated_context_windows)
# Initialize per-modality accumulators (length 1 for single-modality)
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
else:
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
# accumulate results from each context window
for enum_window in enumerated_context_windows:
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
results = self.evaluate_context_windows(
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
model_options, window_state=window_state, total_windows=total_windows)
for result in results:
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
conds_final, counts_final, biases_final)
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
for mod_idx in range(num_modalities):
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
modality_window = result.window.get_window_for_modality(mod_idx)
self.combine_context_window_results(
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
result.window_idx, total_windows, timestep,
accum[mod_idx], counts[mod_idx], biases[mod_idx])
# fuse accumulated results into final conds
try:
# finalize conds
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
# relative is already normalized, so return as is
del counts_final
return conds_final
else:
# normalize conds via division by context usage counts
for i in range(len(conds_final)):
conds_final[i] /= counts_final[i]
del counts_final
return conds_final
result_out = []
for ci in range(len(conds)):
finalized = []
for mod_idx in range(num_modalities):
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
accum[mod_idx][ci] /= counts[mod_idx][ci]
f = accum[mod_idx][ci]
# if guide frames were injected, append them to the end of the fused latents for the next step
if window_state.guide_latents[mod_idx] is not None:
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
finalized.append(f)
# pack modalities together if needed
if window_state.is_multimodal and len(finalized) > 1:
packed, _ = comfy.utils.pack_latents(finalized)
else:
packed = finalized[0]
result_out.append(packed)
return result_out
finally:
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options)
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, device=None, first_device=None):
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
model_options, window_state: WindowingState, total_windows: int = None,
device=None, first_device=None):
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
For each window:
1. Builds windows (for each modality if multimodal)
2. Slices window for each modality
3. Injects concatenated latent guide frames where present
4. Packs together if needed and calls model
5. Unpacks and strips any guides from outputs
"""
x = window_state.latents[0]
results: list[ContextResults] = []
for window_idx, window in enumerated_context_windows:
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
# prepare the window accounting for multimodal windows
window = window_state.prepare_window(window, model)
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
anchor_applied = False
if self.causal_window_fix:
anchor_idx = window.index_list[0] - 1
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
window.causal_anchor_index = anchor_idx
anchor_applied = True
# slice the window for each modality, injecting guide frames where applicable
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
# update exposed params
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
)
# if multimodal, pack modalities together
if window_state.is_multimodal and len(sliced) > 1:
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
else:
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
# get resized conds for window
model_options["transformer_options"]["context_window"] = window
# get subsections of x, timestep, conds
sub_x = window.get_tensor(x_in, device)
sub_timestep = window.get_tensor(timestep, device, dim=0)
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
sub_timestep = window.get_tensor(timestep, dim=0)
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
# if multimodal, patch latent_shapes in conds for correct unpacking in model
window_state.patch_latent_shapes(sub_conds, sub_shapes)
# call model on window
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
# strip causal_window_fix anchor if applied
# unpack outputs
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
if anchor_applied:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
for ci in range(len(out_per_modality)):
t = out_per_modality[ci][0]
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
# strip injected guide frames
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
return results
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
biases_final[i][idx] = bias_total + bias
else:
# add conds and counts based on weights of fuse method
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
for i in range(len(sub_conds_out)):
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
# limit noise_shape length to context_length for more accurate vram use estimation
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
model_options = kwargs.get("model_options", None)
if model_options is None:
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
handler: IndexListContextHandler = model_options.get("context_handler", None)
if handler is not None:
noise_shape = list(noise_shape)
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, *args, **kwargs)
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
if is_packed:
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
pass
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
return executor(model, noise_shape, conds, *args, **kwargs)
def create_prepare_sampling_wrapper(model: ModelPatcher):
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
if not handler.freenoise:
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
def create_sampler_sample_wrapper(model: ModelPatcher):
model.add_wrapper_with_key(
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
_sampler_sample_wrapper
)
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
total_dims = len(x_in.shape)
weights_tensor = torch.Tensor(weights).to(device=device)
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
return ContextSchedule(context_schedule, func)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
    """Compute fuse weights for one window via ``handler.fuse_method.func``.

    ``context_overlap`` falls back to ``handler.context_overlap`` when it is
    ``None``; all other arguments are forwarded to the fuse function as keywords.
    """
    if context_overlap is None:
        context_overlap = handler.context_overlap
    fuse = handler.fuse_method.func
    return fuse(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
def create_weights_flat(length: int, **kwargs) -> list[float]:
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
return weight_sequence
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
# only expected overlap is given different weights
weights_torch = torch.ones((length))
# blend left-side on all except first window
if min(idxs) > 0:
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
weights_torch[:handler.context_overlap] = ramp_up
ramp_up = torch.linspace(1e-37, 1, context_overlap)
weights_torch[:context_overlap] = ramp_up
# blend right-side on all except last window
if max(idxs) < full_length-1:
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
weights_torch[-handler.context_overlap:] = ramp_down
ramp_down = torch.linspace(1, 1e-37, context_overlap)
weights_torch[-context_overlap:] = ramp_down
return weights_torch
class ContextFuseMethods:

View File

@ -1,7 +1,13 @@
import torch
import torch.nn.functional as F
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.depth_anything_3.reference_view_selector import (
select_reference_view, reorder_by_reference, restore_original_order,
THRESH_FOR_REF_SELECTION,
)
class Dino2AttentionOutput(torch.nn.Module):
@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module):
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
             qk_norm=False):
    """Build the attention sub-block.

    When ``qk_norm`` is true, two ``LayerNorm`` layers of width
    ``embed_dim // heads`` are created for Q and K; otherwise ``q_norm`` and
    ``k_norm`` are set to ``None``.
    """
    super().__init__()
    self.heads = heads
    self.head_dim = embed_dim // heads
    self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
    self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
    # Per-head norms exist only when requested; forward() treats q_norm=None as "disabled".
    self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) if qk_norm else None
    self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) if qk_norm else None
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
    """Run attention, optionally with per-head QK-norm and RoPE.

    Without QK-norm and without ``rope`` the call is delegated to
    ``self.attention`` unchanged. Otherwise Q/K/V are projected here, reshaped
    to ``(B, heads, N, head_dim)``, normalised and/or rotated, and passed to
    ``optimized_attention`` with ``skip_reshape=True``.
    """
    qk_norm_enabled = self.q_norm is not None

    # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
    if not qk_norm_enabled and rope is None:
        return self.output(self.attention(x, mask, optimized_attention))

    # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
    proj = self.attention
    batch, tokens, _ = x.shape
    per_head = (batch, tokens, self.heads, self.head_dim)
    q, k, v = (
        layer(x).view(*per_head).transpose(1, 2)
        for layer in (proj.query, proj.key, proj.value)
    )

    if qk_norm_enabled:
        q = self.q_norm(q)
        k = self.k_norm(k)

    # RoPE is applied only when both the module and the positions are supplied.
    if not (rope is None or pos is None):
        q = rope(q, pos)
        k = rope(k, pos)

    attended = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
    return self.output(attended)
class LayerScale(torch.nn.Module):
@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
qk_norm=False):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
qk_norm=qk_norm)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn:
@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module):
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
    """Apply the attention and MLP residual branches in sequence.

    Each branch is: norm -> sub-layer -> layer scale -> residual add.
    ``pos``, ``rope`` and ``attn_mask`` are forwarded to the attention sub-block.
    """
    attn_out = self.attention(self.norm1(x), attn_mask, optimized_attention,
                              pos=pos, rope=rope)
    x = x + self.layer_scale1(attn_out)
    mlp_out = self.mlp(self.norm2(x))
    return x + self.layer_scale2(mlp_out)
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
# -----------------------------------------------------------------------------
# 2D Rotary position embedding (DA3 extension)
# -----------------------------------------------------------------------------
class _PositionGetter:
    """Build and cache the flat (y, x) position grid that is fed to ``rope``.

    Grids are cached per ``(height, width, device)``; every call returns a
    fresh ``(batch_size, height * width, 2)`` copy of the cached grid.
    """

    def __init__(self):
        self._cache: dict = {}

    def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
        key = (height, width, device)
        grid = self._cache.get(key)
        if grid is None:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            grid = torch.cartesian_prod(rows, cols)
            self._cache[key] = grid
        # clone() so callers can modify the result without touching the cached grid
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(torch.nn.Module):
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
def __init__(self, frequency: float = 100.0):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
self.base_frequency = frequency
self._freq_cache: dict = {}
def _components(self, dim: int, seq_len: int, device, dtype):
    """Return cached ``(cos, sin)`` tables of shape ``(seq_len, dim)``.

    Tables are built once per ``(dim, seq_len, device, dtype)`` from inverse
    frequencies ``1 / base_frequency ** (2i / dim)`` and positions
    ``0..seq_len - 1``; the angle matrix is duplicated along the last axis.
    """
    key = (dim, seq_len, device, dtype)
    tables = self._freq_cache.get(key)
    if tables is None:
        exponents = torch.arange(0, dim, 2, device=device).float() / dim
        inv_freq = 1.0 / (self.base_frequency ** exponents)
        positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        # angles are cast to the target dtype before cos/sin are taken
        angles = torch.einsum("i,j->ij", positions, inv_freq).to(dtype)
        angles = torch.cat((angles, angles), dim=-1)
        tables = (angles.cos().to(dtype), angles.sin().to(dtype))
        self._freq_cache[key] = tables
    return tables
@staticmethod
def _rotate(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2:]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d(self, tokens, positions, cos_c, sin_c):
    """Rotate ``tokens`` using the cos/sin rows looked up for ``positions``.

    The looked-up tables get a singleton axis inserted at dim 1 so they
    broadcast against ``tokens``; the result is
    ``tokens * cos + rotate(tokens) * sin``.
    """
    cos_rows, sin_rows = (
        F.embedding(positions, table)[:, None, :, :] for table in (cos_c, sin_c)
    )
    rotated = self._rotate(tokens)
    return (tokens * cos_rows) + (rotated * sin_rows)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Apply 2D rotary embedding to ``tokens``.

    The last axis of ``tokens`` is split into two halves; the first half is
    rotated with ``positions[..., 0]`` and the second with ``positions[..., 1]``,
    then the halves are concatenated back together.
    """
    half_dim = tokens.size(-1) // 2
    # table length must cover the largest position index present
    table_len = int(positions.max()) + 1
    cos_c, sin_c = self._components(half_dim, table_len, tokens.device, tokens.dtype)
    first_half, second_half = tokens.chunk(2, dim=-1)
    rotated = [
        self._apply_1d(part, positions[..., axis], cos_c, sin_c)
        for axis, part in enumerate((first_half, second_half))
    ]
    return torch.cat(rotated, dim=-1)
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
             qknorm_start: int = -1):
    """Stack ``num_layers`` ``Dino2Block`` layers.

    QK-norm is enabled for layer ``i`` only when ``qknorm_start`` is not ``-1``
    and ``i >= qknorm_start``.
    """
    super().__init__()
    blocks = []
    for layer_idx in range(num_layers):
        apply_qk_norm = qknorm_start != -1 and layer_idx >= qknorm_start
        blocks.append(Dino2Block(
            dim, num_heads, layer_norm_eps, dtype, device, operations,
            use_swiglu_ffn=use_swiglu_ffn,
            qk_norm=apply_qk_norm,
        ))
    self.layer = torch.nn.ModuleList(blocks)
def forward(self, x, intermediate_output=None):
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
def __init__(self, dim, dtype, device, operations,
patch_size: int = 14, image_size: int = 518,
use_mask_token: bool = True,
num_camera_tokens: int = 0):
super().__init__()
patch_size = 14
image_size = 518
self.patch_size = patch_size
self.image_size = image_size
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
if use_mask_token:
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
else:
self.mask_token = None
if num_camera_tokens > 0:
# DA3 stores (ref_token, src_token) pairs that get injected at the
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
else:
self.camera_token = None
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module):
patch_pos = pos_embed[:, 1:]
N = patch_pos.shape[1]
M = int(N ** 0.5)
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
h0 = h_pixels // self.patch_size
w0 = w_pixels // self.patch_size
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# scale_factor is (height_scale, width_scale) -- height MUST come first;
# swapping these only happens to work for square inputs and breaks
# non-square paths like DA3-Small / DA3-Base multi-view.
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
assert (h0, w0) == patch_pos.shape[-2:], (
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
f"check scale_factor axis order and +0.1 rounding workaround"
)
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module):
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
patch_size = config_dict.get("patch_size", 14)
image_size = config_dict.get("image_size", 518)
use_mask_token = config_dict.get("use_mask_token", True)
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
# DA3 extensions (all default to disabled).
self.alt_start = config_dict.get("alt_start", -1)
self.qknorm_start = config_dict.get("qknorm_start", -1)
self.rope_start = config_dict.get("rope_start", -1)
self.cat_token = config_dict.get("cat_token", False)
rope_freq = config_dict.get("rope_freq", 100.0)
self.embed_dim = dim
self.patch_size = patch_size
self.num_register_tokens = 0
self.patch_start_idx = 1
if self.rope_start != -1 and rope_freq > 0:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
self._position_getter = _PositionGetter()
else:
self.rope = None
self._position_getter = None
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
num_cam_tokens = 2 if self.alt_start != -1 else 0
self.embeddings = Dino2Embeddings(
dim, dtype, device, operations,
patch_size=patch_size, image_size=image_size,
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
)
self.encoder = Dino2Encoder(
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qknorm_start=self.qknorm_start,
)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
if self.alt_start != -1:
raise RuntimeError(
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
)
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module):
return x, i, pooled_output, None
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
"""Single-view multi-layer feature extraction."""
x = self.embeddings(pixel_values)
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
n_layers = len(self.encoder.layer)
@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module):
if i >= max_idx:
break
return [cache[i] for i in resolved]
# ------------------------------------------------------------------
# Depth Anything 3 forward
# ------------------------------------------------------------------
def _prepare_rope_positions(self, B, S, H, W, device):
    """Build the RoPE position tensors for local and global attention.

    Returns ``(None, None)`` when ``self.rope`` is ``None``. Otherwise returns
    ``(pos_local, pos_global)``; each has ``B * S`` rows made of
    ``patch_start_idx`` all-zero entries followed by one 2-component entry
    per patch:

    * ``pos_local`` holds the grid positions from ``_position_getter``
      shifted by +1, so 0 stays reserved for the leading special token(s).
    * ``pos_global`` sets every patch entry to the constant 1.
    """
    if self.rope is None:
        return None, None

    grid_h = H // self.patch_size
    grid_w = W // self.patch_size
    # +1 keeps position 0 free for the cls/cam token ("no diff").
    patch_pos = self._position_getter(B * S, grid_h, grid_w, device=device) + 1
    special_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=patch_pos.dtype)

    local_positions = torch.cat([special_pos, patch_pos], dim=1)
    # Global variant: special token(s) at 0, every patch at the same value 1.
    uniform_patch_pos = torch.zeros_like(patch_pos) + 1
    global_positions = torch.cat([special_pos, uniform_patch_pos], dim=1)
    return local_positions, global_positions
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
    """Return a copy of ``x`` with token 0 of every view replaced.

    ``x`` is indexed as ``(B, S, N, C)``. When ``cam_token`` is given it is
    written as-is; otherwise the learned ``embeddings.camera_token`` is used:
    its first entry goes to view 0 and its remaining entry is broadcast over
    the other ``S - 1`` views. The input tensor is left untouched.
    """
    if cam_token is None:
        learned = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
        injected = torch.cat(
            [
                learned[:, :1].expand(B, -1, -1),              # reference view
                learned[:, 1:].expand(B, max(S - 1, 0), -1),   # source views
            ],
            dim=1,
        )
    else:
        injected = cam_token

    replaced = x.clone()
    replaced[:, :, 0] = injected
    return replaced
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
    """Multi-view multi-layer feature extraction used by Depth Anything 3.

    Args:
        pixel_values: ``(B, 3, H, W)`` or ``(B, S, 3, H, W)``. A 4-D input is
            treated as one view per batch item (``S == 1``).
        out_layers: Block indices whose outputs are collected.
        cam_token: Optional camera token written over token 0 of every view
            when block ``alt_start`` is reached. Supplying it also disables
            reference-view selection.
        ref_view_strategy: Forwarded to ``select_reference_view``.
        export_feat_layers: Optional block indices whose outputs are also
            returned as auxiliary features.

    Returns:
        ``(features, aux)``. ``features`` is a list of
        ``(patch_tokens, cls_token)`` pairs, one per collected block in
        execution order; ``cls_token`` is sliced from the block output before
        the final layer norm. ``aux`` is the list of layer-normed patch tokens
        for the blocks in ``export_feat_layers``.

    Raises:
        AssertionError: If ``pixel_values`` is not 4-D/5-D with 3 channels.
        ValueError: If a collected output has an unexpected channel width.
    """
    if pixel_values.ndim == 4:
        pixel_values = pixel_values.unsqueeze(1)
    assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
        f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
    B, S, _, H, W = pixel_values.shape

    # Patch + cls + (interpolated) pos embed for each view.
    x = pixel_values.reshape(B * S, 3, H, W)
    x = self.embeddings(x)  # (B*S, 1+N, C)
    x = x.reshape(B, S, x.shape[-2], x.shape[-1])  # (B, S, 1+N, C)

    pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)

    # optimized_attention is only used by blocks without QK-norm/RoPE
    # (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
    optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)

    # Sets give O(1) membership checks inside the block loop.
    out_set = set(out_layers)
    export_set = set(export_feat_layers) if export_feat_layers else set()

    outputs: list[torch.Tensor] = []
    aux_outputs: list[torch.Tensor] = []
    # ``local_x`` tracks the output of the most recent per-view (local) block;
    # it is only consumed when ``cat_token`` is set.
    local_x = x
    # Reference-view permutation; stays None unless selection runs below.
    b_idx = None

    for i, blk in enumerate(self.encoder.layer):
        # RoPE (and its positions) is only handed to blocks from ``rope_start`` on.
        apply_rope = self.rope is not None and i >= self.rope_start
        block_rope = self.rope if apply_rope else None
        l_pos = pos_local if apply_rope else None
        g_pos = pos_global if apply_rope else None

        # Reference-view selection threshold: matches the upstream constant
        # THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
        # cam_token is provided (camera info already pins the geometry).
        if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
            b_idx = select_reference_view(x, strategy=ref_view_strategy)
            x = reorder_by_reference(x, b_idx)
            local_x = reorder_by_reference(local_x, b_idx)

        # The camera token replaces token 0 exactly once, right before block ``alt_start``.
        if self.alt_start != -1 and i == self.alt_start:
            x = self._inject_camera_token(x, B, S, cam_token)

        # From ``alt_start`` on, odd-indexed blocks attend globally and the
        # remaining blocks attend per view.
        if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
            # Global attention across views: flatten S into the seq dim.
            t = x.reshape(B, S * x.shape[-2], x.shape[-1])
            p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
            t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
            x = t.reshape(B, S, x.shape[-2], x.shape[-1])
        else:
            # Per-view local attention.
            t = x.reshape(B * S, x.shape[-2], x.shape[-1])
            p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
            t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
            x = t.reshape(B, S, x.shape[-2], x.shape[-1])
            local_x = x

        if i in out_set:
            if self.cat_token:
                # Channel-wise concat: [latest local output ; current output].
                out_x = torch.cat([local_x, x], dim=-1)
            else:
                out_x = x
            # Restore original view order on the way out so heads see views
            # in the user's expected order.
            if b_idx is not None and self.alt_start != -1:
                out_x = restore_original_order(out_x, b_idx)
            outputs.append(out_x)
        if i in export_set:
            aux = x
            if b_idx is not None and self.alt_start != -1:
                aux = restore_original_order(aux, b_idx)
            aux_outputs.append(aux)

    # Apply final norm. When cat_token is set, only the right half
    # ("global" features) is normalised; the left half is left as-is to
    # match the upstream DA3 head signature.
    normed: list[torch.Tensor] = []
    cls_tokens: list[torch.Tensor] = []
    for out_x in outputs:
        # Token 0 is captured before normalisation.
        cls_tokens.append(out_x[:, :, 0])
        if out_x.shape[-1] == self.embed_dim:
            normed.append(self.layernorm(out_x))
        elif out_x.shape[-1] == self.embed_dim * 2:
            left = out_x[..., :self.embed_dim]
            right = self.layernorm(out_x[..., self.embed_dim:])
            normed.append(torch.cat([left, right], dim=-1))
        else:
            raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")

    # Drop cls/cam token from the patch sequence.
    normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
    # Final layernorm + drop cls token from auxiliary features too.
    aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
                  for o in aux_outputs]

    return list(zip(normed, cls_tokens)), aux_normed

321
comfy/ldm/boogu/model.py Normal file
View File

@ -0,0 +1,321 @@
# Boogu-Image-0.1 transformer
# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an
# added dual-stream ("double_stream") stage before the single-stream layers, conditioned
# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux
# RoPE core; the only new components are the double-stream block and the hybrid forward order.
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
import comfy.ldm.common_dit
import comfy.ldm.omnigen.omnigen2
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.omnigen.omnigen2 import (
OmniGen2RotaryPosEmbed,
Lumina2CombinedTimestepCaptionEmbedding,
LuminaRMSNormZero,
LuminaLayerNormContinuous,
LuminaFeedForward,
Attention,
OmniGen2TransformerBlock,
apply_rotary_emb,
)
class BooguDoubleStreamProcessor(nn.Module):
    """Joint attention over ``[instruct ; img]``.

    Each stream owns its q/k/v and output projections; the q/k norms, head
    counts and the final output projection are read from the ``attn`` module
    passed to :meth:`forward`.
    """

    def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None):
        super().__init__()
        q_width = head_dim * heads
        kv_width = head_dim * kv_heads
        linear_kwargs = {"bias": False, "dtype": dtype, "device": device}

        self.img_to_q = operations.Linear(q_width, q_width, **linear_kwargs)
        self.img_to_k = operations.Linear(q_width, kv_width, **linear_kwargs)
        self.img_to_v = operations.Linear(q_width, kv_width, **linear_kwargs)

        self.instruct_to_q = operations.Linear(q_width, q_width, **linear_kwargs)
        self.instruct_to_k = operations.Linear(q_width, kv_width, **linear_kwargs)
        self.instruct_to_v = operations.Linear(q_width, kv_width, **linear_kwargs)

        self.instruct_out = operations.Linear(q_width, q_width, **linear_kwargs)
        self.img_out = operations.Linear(q_width, q_width, **linear_kwargs)

    def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
        """Run joint attention and return the ``[instruct ; img]`` sequence."""
        bsz = img_hidden_states.shape[0]
        n_instruct = instruct_hidden_states.shape[1]

        def project(img_proj, instruct_proj, n_heads):
            # Instruction tokens go first, then image tokens (matches the
            # reference processor order); heads are split off afterwards.
            img_part = img_proj(img_hidden_states)
            instruct_part = instruct_proj(instruct_hidden_states)
            joined = torch.cat([instruct_part, img_part], dim=1)
            return joined.view(bsz, -1, n_heads, attn.dim_head)

        q = attn.norm_q(project(self.img_to_q, self.instruct_to_q, attn.heads))
        k = attn.norm_k(project(self.img_to_k, self.instruct_to_k, attn.kv_heads))
        v = project(self.img_to_v, self.instruct_to_v, attn.kv_heads)

        if rotary_emb is not None:
            q = apply_rotary_emb(q, rotary_emb)
            k = apply_rotary_emb(k, rotary_emb)

        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        if attn.kv_heads < attn.heads:
            # Grouped-query attention: repeat each kv head to match the query heads.
            group = attn.heads // attn.kv_heads
            k = k.repeat_interleave(group, dim=1)
            v = v.repeat_interleave(group, dim=1)

        attended = optimized_attention_masked(q, k, v, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)

        # Per-stream output projections, then the shared output projection.
        instruct_part = self.instruct_out(attended[:, :n_instruct])
        img_part = self.img_out(attended[:, n_instruct:])
        return attn.to_out[0](torch.cat([instruct_part, img_part], dim=1))
class BooguJointAttention(nn.Module):
    """Container for the shared q/k RMSNorm and the final output projection.

    The attention computation itself is delegated to the owned
    ``BooguDoubleStreamProcessor``, which receives this module as ``attn``.
    """

    def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads, self.kv_heads, self.dim_head = heads, kv_heads, head_dim
        self.scale = head_dim ** -0.5

        self.norm_q = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)

        out_proj = operations.Linear(heads * head_dim, dim, bias=False, dtype=dtype, device=device)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(0.0))

        self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, dtype=dtype, device=device, operations=operations)

    def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
        """Forward everything to the processor, passing ``self`` as ``attn``."""
        return self.processor(
            self,
            img_hidden_states,
            instruct_hidden_states,
            rotary_emb,
            attention_mask,
            transformer_options=transformer_options,
        )
class BooguDoubleStreamBlock(nn.Module):
    """Dual-stream block.

    Combines joint attention over ``[instruct ; img]`` with an image-only
    self-attention. Each stream has its own modulation norms and MLP.
    ``ffn_dim_multiplier`` is accepted but not used here: both MLPs are built
    with ``inner_dim=4 * dim``.
    """

    def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None):
        super().__init__()
        head_dim = dim // num_attention_heads
        shared = {"dtype": dtype, "device": device, "operations": operations}

        def make_ffn():
            return LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, **shared)

        def make_mod_norm():
            return LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, **shared)

        def make_rms():
            return operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)

        self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, **shared)
        self.img_self_attn = Attention(
            query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads,
            eps=1e-5, bias=False, **shared,
        )

        self.img_feed_forward = make_ffn()
        self.instruct_feed_forward = make_ffn()

        self.img_norm1 = make_mod_norm()
        self.img_norm2 = make_mod_norm()
        self.img_norm3 = make_mod_norm()
        self.instruct_norm1 = make_mod_norm()
        self.instruct_norm2 = make_mod_norm()

        self.img_attn_norm = make_rms()
        self.img_self_attn_norm = make_rms()
        self.img_ffn_norm1 = make_rms()
        self.img_ffn_norm2 = make_rms()
        self.instruct_attn_norm = make_rms()
        self.instruct_ffn_norm1 = make_rms()
        self.instruct_ffn_norm2 = make_rms()

    def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}):
        """Return the updated ``(img_hidden_states, instruct_hidden_states)``."""
        n_instruct = instruct_hidden_states.shape[1]

        def gated(gate, branch):
            # tanh-squashed per-sample gate, broadcast over the token axis.
            return gate.unsqueeze(1).tanh() * branch

        def modulated(scale, normed, shift):
            return (1 + scale.unsqueeze(1)) * normed + shift.unsqueeze(1)

        # All modulation outputs are derived from the block inputs, before
        # any residual update is applied.
        img_joint_in, img_gate_joint, img_mlp_scale, img_mlp_gate = self.img_norm1(img_hidden_states, temb)
        img_mlp_in, img_mlp_shift, _, _ = self.img_norm2(img_hidden_states, temb)
        img_self_in, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb)
        txt_joint_in, txt_gate_joint, txt_mlp_scale, txt_mlp_gate = self.instruct_norm1(instruct_hidden_states, temb)
        txt_mlp_in, txt_mlp_shift, _, _ = self.instruct_norm2(instruct_hidden_states, temb)

        joint_out = self.img_instruct_attn(img_joint_in, txt_joint_in, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options)
        txt_attn_out = joint_out[:, :n_instruct]
        img_attn_out = joint_out[:, n_instruct:]
        img_self_out = self.img_self_attn(img_self_in, img_self_in, img_attention_mask, img_rotary_emb, transformer_options=transformer_options)

        # Image stream: joint attention, image self-attention, then MLP.
        img = img_hidden_states + gated(img_gate_joint, self.img_attn_norm(img_attn_out))
        img = img + gated(img_gate_self, self.img_self_attn_norm(img_self_out))
        img_ffn_out = self.img_feed_forward(self.img_ffn_norm1(modulated(img_mlp_scale, img_mlp_in, img_mlp_shift)))
        img = img + gated(img_mlp_gate, self.img_ffn_norm2(img_ffn_out))

        # Instruction stream: joint attention, then MLP.
        txt = instruct_hidden_states + gated(txt_gate_joint, self.instruct_attn_norm(txt_attn_out))
        txt_ffn_out = self.instruct_feed_forward(self.instruct_ffn_norm1(modulated(txt_mlp_scale, txt_mlp_in, txt_mlp_shift)))
        txt = txt + gated(txt_mlp_gate, self.instruct_ffn_norm2(txt_ffn_out))

        return img, txt
class BooguTransformer2DModel(nn.Module):
    """Boogu image transformer.

    Pipeline, as wired in :meth:`forward`: timestep/caption embedding ->
    context refiner on the text tokens -> patch embedding and refinement of
    the noise / reference-image tokens -> double-stream blocks operating on
    (image, text) separately -> single-stream blocks on the concatenated
    ``[text ; image]`` sequence -> output norm and un-patchify.
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        out_channels: Optional[int] = None,
        hidden_size: int = 3360,
        num_layers: int = 32,
        num_double_stream_layers: int = 8,
        num_refiner_layers: int = 2,
        num_attention_heads: int = 28,
        num_kv_heads: int = 7,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
        axes_dim_rope: Tuple[int, int, int] = (40, 40, 40),
        axes_lens: Tuple[int, int, int] = (2048, 1664, 1664),
        instruction_feat_dim: int = 4096,
        timestep_scale: float = 1000.0,
        image_model=None,
        device=None, dtype=None, operations=None,
    ):
        """Build all sub-modules.

        ``out_channels`` falls back to ``in_channels`` when not given.
        ``num_layers`` is the number of single-stream blocks and
        ``num_double_stream_layers`` the number of double-stream blocks;
        ``num_refiner_layers`` is used for each of the three refiners.
        ``image_model`` is accepted but not referenced in this constructor.
        """
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels or in_channels
        self.hidden_size = hidden_size
        self.dtype = dtype

        self.rope_embedder = OmniGen2RotaryPosEmbed(
            theta=10000,
            axes_dim=axes_dim_rope,
            axes_lens=axes_lens,
            patch_size=patch_size,
        )

        # Separate linear patch embedders for the noisy latent and the reference images.
        self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
        self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)

        self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
            hidden_size=hidden_size,
            text_feat_dim=instruction_feat_dim,
            norm_eps=norm_eps,
            timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
        )

        # The noise and reference-image refiners are modulated (``modulation=True``);
        # the context refiner is built with ``modulation=False``.
        self.noise_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])
        self.ref_image_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])
        self.context_refiner = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations)
            for _ in range(num_refiner_layers)
        ])

        self.double_stream_layers = nn.ModuleList([
            BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations)
            for _ in range(num_double_stream_layers)
        ])
        self.single_stream_layers = nn.ModuleList([
            OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
            for _ in range(num_layers)
        ])

        self.norm_out = LuminaLayerNormContinuous(
            embedding_dim=hidden_size,
            conditioning_embedding_dim=min(hidden_size, 1024),
            elementwise_affine=False,
            eps=1e-6,
            out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
        )

        # Uninitialised (``torch.empty``) parameter with 5 rows of width ``hidden_size``.
        # NOTE(review): presumably consumed by the borrowed OmniGen2 helper to
        # tag reference images by index -- confirm against omnigen2.py.
        self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))

    # Patchify/refine helpers are identical to OmniGen2; reuse via bound methods.
    flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq
    img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine

    def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
        """Run the transformer on latent ``x`` and return a tensor of ``x``'s spatial size.

        ``x`` is unpacked as ``(B, C, H, W)``; it is padded to a multiple of
        the patch size and the padding is cropped off the output again.
        ``context`` / ``attention_mask`` are the text features and their mask,
        ``num_tokens`` is repeated once per batch item for the RoPE embedder,
        and ``ref_latents`` optionally carries reference-image latents.
        Extra ``kwargs`` are ignored.
        """
        B, C, H, W = x.shape
        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        _, _, H_padded, W_padded = hidden_states.shape
        # NOTE(review): the time axis is flipped here and the output is negated
        # at the end; presumably this converts between the caller's and the
        # reference model's flow conventions -- confirm.
        timestep = 1.0 - timesteps
        text_hidden_states = context
        text_attention_mask = attention_mask
        ref_image_hidden_states = ref_latents
        device = hidden_states.device

        temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)

        (
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes,
        ) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)

        (
            context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
            rotary_emb, encoder_seq_lengths, seq_lengths,
        ) = self.rope_embedder(
            hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
            l_effective_ref_img_len, l_effective_img_len,
            ref_img_sizes, img_sizes, device,
        )

        for layer in self.context_refiner:
            text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)

        # Length of the noise-token sequence; used at the end to slice the
        # last ``img_len`` tokens out of the joint sequence.
        img_len = hidden_states.shape[1]
        combined_img_hidden_states = self.img_patch_embed_and_refine(
            hidden_states, ref_image_hidden_states,
            img_mask, ref_img_mask,
            noise_rotary_emb, ref_img_rotary_emb,
            l_effective_ref_img_len, l_effective_img_len,
            temb,
            transformer_options=transformer_options,
        )

        # Double-stream stage: the image self-attention only sees the [ref ; noise] tokens,
        # which sit after the instruction tokens in the joint rope.
        L_instruct = text_hidden_states.shape[1]
        combined_img_rotary_emb = rotary_emb[:, L_instruct:]
        for layer in self.double_stream_layers:
            combined_img_hidden_states, text_hidden_states = layer(
                combined_img_hidden_states, text_hidden_states,
                rotary_emb, combined_img_rotary_emb, temb,
                joint_attention_mask=None, img_attention_mask=None,
                transformer_options=transformer_options,
            )

        # Single-stream stage runs on the concatenated [text ; image] sequence with no mask.
        hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
        for layer in self.single_stream_layers:
            hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options)

        hidden_states = self.norm_out(hidden_states, temb)

        # Un-patchify the trailing ``img_len`` tokens and crop the padding back to (H, W).
        p = self.patch_size
        output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W]
        return -output

25
comfy/ldm/colormap.py Normal file
View File

@ -0,0 +1,25 @@
"""Colormap utilities for depth and geometry visualisation."""
from __future__ import annotations
import torch
def turbo(x: torch.Tensor) -> torch.Tensor:
    """Map values to RGB with a degree-5 polynomial fit of the Turbo colormap.

    This is Anton Mikhailov's polynomial approximation.

    Args:
        x: Float tensor; values are clamped to [0, 1] before evaluation.

    Returns:
        Tensor shaped like ``x`` plus a trailing size-3 (R, G, B) dimension,
        clamped to [0, 1].
    """
    p1 = x.clamp(0.0, 1.0)
    p2 = p1 * p1
    p3 = p2 * p1
    p4 = p2 * p2
    p5 = p4 * p1
    powers = (p1, p2, p3, p4, p5)

    # One row per channel: constant term followed by the x..x^5 coefficients.
    coefficients = (
        (0.13572138, 4.61539260, -42.66032258, 132.13108234, -152.94239396, 59.28637943),
        (0.09140261, 2.19418839, 4.84296658, -14.18503333, 4.27729857, 2.82956604),
        (0.10667330, 12.64194608, -60.58204836, 110.36276771, -89.90310912, 27.34824973),
    )

    channels = []
    for constant, *weights in coefficients:
        value = constant
        # Terms are accumulated left to right, lowest power first.
        for weight, power in zip(weights, powers):
            value = value + weight * power
        channels.append(value)
    return torch.stack(channels, dim=-1).clamp(0.0, 1.0)

View File

@ -515,7 +515,7 @@ class Block(nn.Module):
h=H,
w=W,
)
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
def _x_fn(
_x_B_T_H_W_D: torch.Tensor,
@ -548,7 +548,7 @@ class Block(nn.Module):
shift_cross_attn_B_T_1_1_D,
transformer_options=transformer_options,
)
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
normalized_x_B_T_H_W_D = _fn(
x_B_T_H_W_D,
@ -557,7 +557,7 @@ class Block(nn.Module):
shift_mlp_B_T_1_1_D,
)
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
return x_B_T_H_W_D

View File

@ -0,0 +1,177 @@
"""Camera-token encoder and decoder for Depth Anything 3."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention_for_device
from .transform import affine_inverse, extri_intri_to_pose_encoding
# -----------------------------------------------------------------------
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
# -----------------------------------------------------------------------
class _Mlp(nn.Module):
    """Two-layer MLP with a GELU in between. Matches upstream ``utils.attention.Mlp``.

    ``hidden_features`` and ``out_features`` both default to ``in_features``
    when falsy.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
        super().__init__()
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        linear_kwargs = {"bias": True, "device": device, "dtype": dtype}
        self.fc1 = operations.Linear(in_features, width_hidden, **linear_kwargs)
        self.fc2 = operations.Linear(width_hidden, width_out, **linear_kwargs)

    def forward(self, x):
        hidden = F.gelu(self.fc1(x))
        return self.fc2(hidden)
class _LayerScale(nn.Module):
    """Learnable per-channel scale. Matches upstream LayerScale.

    ``gamma`` is allocated uninitialised (``torch.empty``); its values are
    expected to come from loaded weights.
    """

    def __init__(self, dim, *, device=None, dtype=None):
        super().__init__()
        self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x):
        # Cast the scale to the input's dtype/device before multiplying.
        scale = self.gamma.to(dtype=x.dtype, device=x.device)
        return x * scale
class _Attention(nn.Module):
    """Multi-head self-attention with one fused QKV projection.

    Mirrors upstream ``utils.attention.Attention``; the parameter layout
    matches the HF safetensors (``attn.qkv.{weight,bias}`` and
    ``attn.proj.{weight,bias}``).
    """

    def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        linear_kwargs = {"bias": True, "device": device, "dtype": dtype}
        self.qkv = operations.Linear(dim, dim * 3, **linear_kwargs)
        self.proj = operations.Linear(dim, dim, **linear_kwargs)

    def forward(self, x):
        batch, tokens, channels = x.shape
        # Split the fused projection into three (batch, tokens, channels) tensors.
        fused = self.qkv(x).reshape(batch, tokens, 3, channels)
        query, key, value = fused.unbind(2)
        attend = optimized_attention_for_device(x.device, small_input=True)
        return self.proj(attend(query, key, value, heads=self.num_heads))
class _Block(nn.Module):
    """Pre-norm transformer block with LayerScale, used by :class:`CameraEnc`.

    Layout follows upstream ``utils.block.Block``. A falsy ``init_values``
    replaces both LayerScale modules with ``nn.Identity``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        def make_layer_scale():
            if init_values:
                return _LayerScale(dim, **factory)
            return nn.Identity()

        self.norm1 = operations.LayerNorm(dim, **factory)
        self.attn = _Attention(dim, num_heads, operations=operations, **factory)
        self.ls1 = make_layer_scale()
        self.norm2 = operations.LayerNorm(dim, **factory)
        self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), operations=operations, **factory)
        self.ls2 = make_layer_scale()

    def forward(self, x):
        # Residual attention branch, then residual MLP branch.
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + attn_branch
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + mlp_branch
class CameraEnc(nn.Module):
    """Encode per-view (extrinsics, intrinsics) into a camera token.

    A pose-encoding vector (``dim_in`` wide, 9 by default) is lifted to
    ``dim_out`` by a small MLP, normalised, passed through ``trunk_depth``
    transformer blocks and normalised again. The resulting
    ``(B, S, dim_out)`` tokens are injected at block ``alt_start`` of the
    DINOv2 backbone in place of the cls token.

    Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
    Unknown keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        dim_out: int = 1024,
        dim_in: int = 9,
        trunk_depth: int = 4,
        target_dim: int = 9,
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        *,
        device=None, dtype=None, operations=None,
        **_kwargs,
    ):
        super().__init__()
        self.target_dim = target_dim
        self.trunk_depth = trunk_depth

        factory = {"device": device, "dtype": dtype}
        trunk_blocks = []
        for _ in range(trunk_depth):
            trunk_blocks.append(
                _Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
                       init_values=init_values, operations=operations, **factory)
            )
        self.trunk = nn.Sequential(*trunk_blocks)

        self.token_norm = operations.LayerNorm(dim_out, **factory)
        self.trunk_norm = operations.LayerNorm(dim_out, **factory)
        self.pose_branch = _Mlp(
            in_features=dim_in,
            hidden_features=dim_out // 2,
            out_features=dim_out,
            operations=operations,
            **factory,
        )

    def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
                image_size_hw) -> torch.Tensor:
        """Encode camera parameters into ``(B, S, dim_out)`` tokens."""
        inverted = affine_inverse(extrinsics)
        pose = extri_intri_to_pose_encoding(inverted, intrinsics, image_size_hw)
        # Match the MLP's weight dtype before projecting.
        weight_dtype = self.pose_branch.fc1.weight.dtype
        tokens = self.token_norm(self.pose_branch(pose.to(weight_dtype)))
        return self.trunk_norm(self.trunk(tokens))
class CameraDec(nn.Module):
    """Decode the final cam token into a 9-D pose encoding.

    Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
    always predicted. The quaternion and FoV are predicted unless
    ``camera_encoding`` is supplied, in which case they are copied from it
    (used at training time when GT cameras are available -- not exercised at
    inference here).

    Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
    Unknown keyword arguments are accepted and ignored.
    """

    def __init__(self, dim_in: int = 1536,
                 *, device=None, dtype=None, operations=None, **_kwargs):
        super().__init__()
        factory = {"device": device, "dtype": dtype}

        def linear(out_features):
            return operations.Linear(dim_in, out_features, **factory)

        self.backbone = nn.Sequential(linear(dim_in), nn.ReLU(), linear(dim_in), nn.ReLU())
        self.fc_t = linear(3)
        self.fc_qvec = linear(4)
        # The trailing ReLU keeps the predicted FoV values non-negative.
        self.fc_fov = nn.Sequential(linear(2), nn.ReLU())

    def forward(self, feat: torch.Tensor,
                camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
        """Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
        batch, views = feat.shape[:2]
        hidden = self.backbone(feat.reshape(batch * views, -1))
        # The output heads are evaluated on a float32 copy of the features.
        hidden_fp32 = hidden.float()

        translation = self.fc_t(hidden_fp32).reshape(batch, views, 3)
        if camera_encoding is not None:
            rotation = camera_encoding[..., 3:7]
            fov = camera_encoding[..., -2:]
        else:
            rotation = self.fc_qvec(hidden_fp32).reshape(batch, views, 4)
            fov = self.fc_fov(hidden_fp32).reshape(batch, views, 2)
        return torch.cat([translation, rotation, fov], dim=-1)

View File

@ -0,0 +1,489 @@
"""DPT / DualDPT heads for Depth Anything 3."""
from __future__ import annotations
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class Permute(nn.Module):
    """Parameter-free module that reorders tensor axes to a fixed layout.

    Lets an axis permutation sit inside an ``nn.Sequential``.
    """

    def __init__(self, dims: Tuple[int, ...]):
        super().__init__()
        # Target axis order, applied verbatim on every call.
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with its axes permuted to ``self.dims``."""
        axis_order = tuple(self.dims)
        return torch.permute(x, axis_order)
def _custom_interpolate(
x: torch.Tensor,
size: Optional[Tuple[int, int]] = None,
scale_factor: Optional[float] = None,
mode: str = "bilinear",
align_corners: bool = True,
) -> torch.Tensor:
if size is None:
assert scale_factor is not None
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
INT_MAX = 1610612736
total = size[0] * size[1] * x.shape[0] * x.shape[1]
if total > INT_MAX:
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
return torch.cat(outs, dim=0).contiguous()
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
left_x = -span_x * (width - 1) / width
right_x = span_x * (width - 1) / width
top_y = -span_y * (height - 1) / height
bottom_y = span_y * (height - 1) / height
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
pos = pos.reshape(-1)
out = torch.einsum("m,d->md", pos, omega)
return torch.cat([out.sin(), out.cos()], dim=1).float()
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
    """Turn an ``(H, W, 2)`` coordinate grid into an ``(H, W, embed_dim)`` embedding.

    Each of the two coordinates gets a sin/cos embedding of size
    ``embed_dim // 2``; the two halves are concatenated on the last axis.
    """
    rows, cols, _ = pos_grid.shape
    coords = pos_grid.reshape(-1, 2)
    per_axis_dim = embed_dim // 2
    halves = [
        _make_sincos_pos_embed(per_axis_dim, coords[:, axis], omega_0=omega_0)
        for axis in (0, 1)
    ]
    return torch.cat(halves, dim=-1).view(rows, cols, embed_dim)
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
    """Stateless UV positional embedding added to a feature map (B, C, h, w).

    The UV grid is built at the feature map's own resolution with aspect
    ratio ``W / H``, embedded to ``C`` channels, scaled by ``ratio`` and
    broadcast over the batch before being added to ``x``.
    """
    batch, channels = x.shape[0], x.shape[1]
    grid_h, grid_w = x.shape[-2], x.shape[-1]
    uv = _create_uv_grid(grid_w, grid_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
    embed = _position_grid_to_embed(uv, channels) * ratio
    # (h, w, C) -> (1, C, h, w) -> broadcast view over the batch.
    embed = embed.permute(2, 0, 1).unsqueeze(0).expand(batch, -1, -1, -1)
    return x + embed.to(dtype=x.dtype)
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
act = (activation or "linear").lower()
if act == "exp":
return torch.exp(x)
if act == "expp1":
return torch.exp(x) + 1
if act == "expm1":
return torch.expm1(x)
if act == "relu":
return torch.relu(x)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "softplus":
return F.softplus(x)
if act == "tanh":
return torch.tanh(x)
return x
# -----------------------------------------------------------------------------
# Fusion building blocks
# -----------------------------------------------------------------------------
class ResidualConvUnit(nn.Module):
    """Pre-activation residual unit: ``x + conv2(relu(conv1(relu(x))))``.

    Both convolutions are 3x3, stride 1, padding 1, so the channel count and
    spatial size of the input are preserved.
    """

    def __init__(self, features: int, device=None, dtype=None, operations=None):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True, device=device, dtype=dtype)
        self.conv1 = operations.Conv2d(features, features, **conv_kwargs)
        self.conv2 = operations.Conv2d(features, features, **conv_kwargs)
        # Non-inplace so the input kept for the skip connection is not overwritten.
        self.activation = nn.ReLU(inplace=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the two ReLU->conv stages."""
        branch = x
        for conv in (self.conv1, self.conv2):
            branch = conv(self.activation(branch))
        return branch + x
class FeatureFusionBlock(nn.Module):
    """Fuse a coarse feature map with an optional skip input, then upsample.

    ``forward(top, skip)`` adds ``resConfUnit1(skip)`` to ``top`` (when the
    block was built with ``has_residual`` and a skip is given), refines with
    ``resConfUnit2``, resizes bilinearly and applies a 1x1 ``out_conv``.
    """

    def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
        super().__init__()
        self.align_corners = align_corners
        self.has_residual = has_residual
        unit_kwargs = dict(device=device, dtype=dtype, operations=operations)
        # The skip-path unit only exists when a residual input is expected.
        self.resConfUnit1 = ResidualConvUnit(features, **unit_kwargs) if has_residual else None
        self.resConfUnit2 = ResidualConvUnit(features, **unit_kwargs)
        self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)

    def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Fuse ``xs[0]`` with optional ``xs[1]`` and resize.

        With ``size=None`` the map is upsampled by a factor of 2; otherwise it
        is resized to exactly ``size``.
        """
        fused = xs[0]
        skip_usable = len(xs) > 1 and self.resConfUnit1 is not None
        if self.has_residual and skip_usable:
            fused = fused + self.resConfUnit1(xs[1])
        fused = self.resConfUnit2(fused)

        target = {"scale_factor": 2.0} if size is None else {"size": size}
        fused = _custom_interpolate(fused, mode="bilinear", align_corners=self.align_corners, **target)
        return self.out_conv(fused)
class _Scratch(nn.Module):
    """Container that mirrors upstream ``scratch`` attribute layout.

    Defines no layers of its own; ``_make_scratch`` and the head constructors
    attach submodules to an instance as plain attributes.
    """
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
    """Build a ``_Scratch`` holding the four ``layer{1..4}_rn`` convolutions.

    Each is a bias-free 3x3 conv (stride 1, padding 1) mapping
    ``in_shape[i]`` channels to ``out_shape`` channels.
    """
    scratch = _Scratch()
    for level in range(4):
        conv = operations.Conv2d(in_shape[level], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
        # Attribute names are 1-based: layer1_rn .. layer4_rn.
        setattr(scratch, f"layer{level + 1}_rn", conv)
    return scratch
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
    """Create a ``FeatureFusionBlock`` with ``align_corners`` fixed to ``True``."""
    block_kwargs = dict(
        has_residual=has_residual,
        align_corners=True,
        device=device,
        dtype=dtype,
        operations=operations,
    )
    return FeatureFusionBlock(features, **block_kwargs)
# -----------------------------------------------------------------------------
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
# -----------------------------------------------------------------------------
class DPT(nn.Module):
    """Single-head DPT used by DA3Mono-Large and DA3Metric-Large.

    Projects four backbone feature levels into a conv pyramid, fuses them
    top-down with ``FeatureFusionBlock``s and predicts a dense map stored
    under ``head_name`` (plus ``{head_name}_conf`` when ``output_dim > 1``)
    and, when ``use_sky_head`` is set, a one-channel map under ``sky_name``.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 1,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = False,
        down_ratio: int = 1,
        head_name: str = "depth",
        use_sky_head: bool = True,
        sky_name: str = "sky",
        sky_activation: str = "relu",
        norm_type: str = "idt",
        device=None, dtype=None, operations=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio
        self.head_main = head_name
        self.sky_name = sky_name
        self.out_dim = output_dim
        # The last output channel is treated as confidence when there is more than one.
        self.has_conf = output_dim > 1
        self.use_sky_head = use_sky_head
        self.sky_activation = sky_activation
        # Which entries of ``feats`` feed pyramid stages 0..3.
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
        # Only ``norm_type == "layer"`` normalises tokens; any other value is a no-op.
        if norm_type == "layer":
            self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
        else:
            self.norm = nn.Identity()
        out_channels = list(out_channels)
        # 1x1 projections from the token width to each level's channel count.
        self.projects = nn.ModuleList([
            operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
            for oc in out_channels
        ])
        # Per-level resampling: x4 and x2 transposed convs, identity, stride-2 conv.
        self.resize_layers = nn.ModuleList([
            operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
            operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
            nn.Identity(),
            operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
        ])
        self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        # The coarsest block has no skip input to merge.
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = operations.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
            device=device, dtype=dtype,
        )
        self.scratch.output_conv2 = nn.Sequential(
            operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
            nn.ReLU(inplace=False),
            operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
        )
        # Sky head: same shape as the main projection but always one output channel.
        if self.use_sky_head:
            self.scratch.sky_output_conv2 = nn.Sequential(
                operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
                nn.ReLU(inplace=False),
                operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
            )

    def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
        """Fuse four feature levels into dense per-view predictions.

        Args:
            feats: per-level features; ``feats[i][0]`` is the patch-token
                tensor with shape (B, S, N_patch, C).
            H, W: sizes used to derive the patch grid
                (``H // patch_size``, ``W // patch_size``) and the output
                resolution.
            patch_start_idx: number of leading tokens dropped before the
                tokens are reshaped into the patch grid.

        Returns:
            Dict with ``head_main`` and, depending on configuration,
            ``{head_main}_conf`` and ``sky_name``; every value is viewed
            with leading dims ``(B, S)``.
        """
        # feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
        B, S, N, C = feats[0][0].shape
        feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
        ph, pw = H // self.patch_size, W // self.patch_size
        resized = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats_flat[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            # (B*S, N, C) tokens -> (B*S, C, ph, pw) feature map.
            x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = _add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)
            resized.append(x)
        l1_rn = self.scratch.layer1_rn(resized[0])
        l2_rn = self.scratch.layer2_rn(resized[1])
        l3_rn = self.scratch.layer3_rn(resized[2])
        l4_rn = self.scratch.layer4_rn(resized[3])
        # Top-down fusion: each step is resized to the next finer level's size.
        out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
        out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
        out = self.scratch.refinenet1(out, l1_rn)
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)
        fused = self.scratch.output_conv1(out)
        fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
        if self.pos_embed:
            fused = _add_pos_embed(fused, W, H)
        feat = fused
        main_logits = self.scratch.output_conv2(feat)
        outs = {}
        if self.has_conf:
            # Channels-last view: all but the last channel are the prediction,
            # the last channel is the confidence.
            fmap = main_logits.permute(0, 2, 3, 1)
            pred = _apply_activation(fmap[..., :-1], self.activation)
            conf = _apply_activation(fmap[..., -1], self.conf_activation)
            outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
            outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
        else:
            pred = _apply_activation(main_logits, self.activation)
            outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
        if self.use_sky_head:
            # The sky head reads the same fused features as the main head.
            sky_logits = self.scratch.sky_output_conv2(feat)
            if self.sky_activation.lower() == "sigmoid":
                sky = torch.sigmoid(sky_logits)
            elif self.sky_activation.lower() == "relu":
                sky = F.relu(sky_logits)
            else:
                sky = sky_logits
            outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
        return outs
# -----------------------------------------------------------------------------
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
# -----------------------------------------------------------------------------
class DualDPT(nn.Module):
    """Two-head DPT used by DA3-Small / DA3-Base.

    The main head always produces ``head_names[0]`` and its ``_conf`` map.
    The auxiliary head (``head_names[1]``) has its own fusion chain and is
    only evaluated while ``enable_aux`` is true.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 2,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        aux_pyramid_levels: int = 4,
        aux_out1_conv_num: int = 5,
        head_names: Tuple[str, str] = ("depth", "ray"),
        device=None, dtype=None, operations=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio
        self.aux_levels = aux_pyramid_levels
        # Stored only; ``_make_aux_out1_block`` always builds five convs.
        self.aux_out1_conv_num = aux_out1_conv_num
        self.head_main, self.head_aux = head_names
        # Which entries of ``feats`` feed pyramid stages 0..3.
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
        # Toggle the auxiliary ray branch at runtime. Default off (mono path).
        # DepthAnything3Net flips this on when running multi-view + ray-pose.
        self.enable_aux: bool = False
        self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
        out_channels = list(out_channels)
        # 1x1 projections from the token width to each level's channel count.
        self.projects = nn.ModuleList([
            operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
            for oc in out_channels
        ])
        # Per-level resampling: x4 and x2 transposed convs, identity, stride-2 conv.
        self.resize_layers = nn.ModuleList([
            operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
            operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
            nn.Identity(),
            operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
        ])
        self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
        # Main fusion chain
        self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
        # Auxiliary fusion chain (separate copies)
        self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
        head_features_1 = features
        head_features_2 = 32
        # Main head neck + final projection
        self.scratch.output_conv1 = operations.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
            device=device, dtype=dtype,
        )
        self.scratch.output_conv2 = nn.Sequential(
            operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
            nn.ReLU(inplace=False),
            operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
        )
        # Aux pre-head per level (multi-level pyramid)
        self.scratch.output_conv1_aux = nn.ModuleList([
            self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
            for _ in range(self.aux_levels)
        ])
        # Aux final projection per level (includes LayerNorm permute path).
        # NOTE: ``ln_seq`` is built once and unpacked into every level below,
        # so all levels hold the same Permute/LayerNorm module instances.
        ln_seq = [Permute((0, 2, 3, 1)),
                  operations.LayerNorm(head_features_2, device=device, dtype=dtype),
                  Permute((0, 3, 1, 2))]
        self.scratch.output_conv2_aux = nn.ModuleList([
            nn.Sequential(
                operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
                *ln_seq,
                nn.ReLU(inplace=False),
                operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
            )
            for _ in range(self.aux_levels)
        ])

    @staticmethod
    def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
        """Build the five-conv aux pre-head; channels alternate and end at ``in_ch // 2``."""
        # aux_out1_conv_num=5 in all Apache-2.0 variants.
        return nn.Sequential(
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
        )

    def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
        """Run the main head and, if ``enable_aux`` is set, the aux head.

        Args:
            feats: per-level features; ``feats[i][0]`` is reshaped from
                (B, S, N, C) to (B*S, N, C).
            H, W: sizes used to derive the patch grid and output resolution.
            patch_start_idx: number of leading tokens dropped before the
                tokens are reshaped into the patch grid.

        Returns:
            Dict with ``head_main`` and ``{head_main}_conf``; with
            ``enable_aux`` also ``head_aux`` and ``{head_aux}_conf``.
        """
        B, S, N, C = feats[0][0].shape
        feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
        ph, pw = H // self.patch_size, W // self.patch_size
        resized = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats_flat[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            # (B*S, N, C) tokens -> (B*S, C, ph, pw) feature map.
            x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = _add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)
            resized.append(x)
        l1_rn = self.scratch.layer1_rn(resized[0])
        l2_rn = self.scratch.layer2_rn(resized[1])
        l3_rn = self.scratch.layer3_rn(resized[2])
        l4_rn = self.scratch.layer4_rn(resized[3])
        # Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
        # before interpolation -- replicate that order here).
        m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        if self.enable_aux:
            # The aux chain consumes the same skip inputs as the main chain
            # but feeds on its own previous output.
            a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
            aux_pyr = [a4]
        m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
        if self.enable_aux:
            aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
        m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
        if self.enable_aux:
            aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
        m = self.scratch.refinenet1(m, l1_rn)
        if self.enable_aux:
            aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
        m = self.scratch.output_conv1(m)
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)
        m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
        if self.pos_embed:
            m = _add_pos_embed(m, W, H)
        main_logits = self.scratch.output_conv2(m)
        # Channels-last view: last channel is confidence, the rest the prediction.
        fmap = main_logits.permute(0, 2, 3, 1)
        depth_pred = _apply_activation(fmap[..., :-1], self.activation)
        depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
        outs = {
            self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
            f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
        }
        if self.enable_aux:
            # Auxiliary "ray" head (multi-level inside) -- only the last level
            # is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
            # each aux pyramid level goes through ``output_conv1_aux[i]``
            # (5-layer conv stack that ends at ``features // 2`` channels),
            # then the last level optionally gets a pos-embed and finally
            # ``output_conv2_aux[-1]``.
            aux_processed = [
                self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
            ]
            last_aux = aux_processed[-1]
            if self.pos_embed:
                last_aux = _add_pos_embed(last_aux, W, H)
            last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
            fmap_last = last_aux_logits.permute(0, 2, 3, 1)
            # Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
            aux_pred = fmap_last[..., :-1]
            aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
            outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
            outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
        return outs

View File

@ -0,0 +1,236 @@
from __future__ import annotations
from typing import Dict, Optional, Sequence
import torch
import torch.nn as nn
from comfy.image_encoders.dino2 import Dinov2Model
from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri
# Maps the lower-cased ``head_type`` constructor argument to its head class.
_HEAD_REGISTRY = {
    "dpt": DPT,
    "dualdpt": DualDPT,
}
# Backbone presets (mirror the upstream DINOv2 ViT variants).
# Keyed by ``backbone_name``; ``_build_backbone_config`` copies the selected
# entry and layers the shared settings on top of it.
_BACKBONE_PRESETS = {
    "vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
    "vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
    "vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
    "vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
}
def _build_backbone_config(
    backbone_name: str,
    *,
    alt_start: int,
    qknorm_start: int,
    rope_start: int,
    cat_token: bool,
) -> dict:
    """Return the backbone config dict for ``backbone_name``.

    Starts from a copy of the matching ``_BACKBONE_PRESETS`` entry and adds
    the fixed settings plus the caller-supplied values.

    Raises:
        ValueError: if ``backbone_name`` has no preset.
    """
    preset = _BACKBONE_PRESETS.get(backbone_name)
    if preset is None:
        raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
    return {
        **preset,
        "layer_norm_eps": 1e-6,
        "patch_size": 14,
        "image_size": 518,
        # No mask_token in DA3 weights; omit param to avoid load warnings.
        "use_mask_token": False,
        "alt_start": alt_start,
        "qknorm_start": qknorm_start,
        "rope_start": rope_start,
        "cat_token": cat_token,
        "rope_freq": 100.0,
    }
class DepthAnything3Net(nn.Module):
    """Depth Anything 3 network: DINOv2 backbone, a DPT/DualDPT head and
    optional camera encoder/decoder modules.

    ``forward`` returns the head's dense outputs and, when pose can be
    derived (ray head or camera decoder), ``extrinsics`` / ``intrinsics``.
    """

    # Patch size handed to the head; input H and W must be multiples of it.
    PATCH_SIZE = 14

    def __init__(
        self,
        # --- Backbone ---
        backbone_name: str = "vitl",
        out_layers: Sequence[int] = (4, 11, 17, 23),
        alt_start: int = -1,
        qknorm_start: int = -1,
        rope_start: int = -1,
        cat_token: bool = False,
        # --- Head ---
        head_type: str = "dpt", # dpt or dualdpt
        head_dim_in: int = 1024,
        head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
        head_features: int = 256,
        head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
        head_use_sky_head: bool = True, # ignored by DualDPT
        head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
        # --- Camera (multi-view) ---
        has_cam_enc: bool = False,
        has_cam_dec: bool = False,
        cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
        cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
        # ComfyUI plumbing
        device=None, dtype=None, operations=None,
        **_ignored,
    ):
        """Build backbone, head and optional camera modules.

        Raises ``KeyError`` for a ``head_type`` missing from
        ``_HEAD_REGISTRY`` and ``ValueError`` (from
        ``_build_backbone_config``) for an unknown ``backbone_name``.
        Extra keyword arguments are accepted and ignored.
        """
        super().__init__()
        head_cls = _HEAD_REGISTRY[head_type.lower()]
        self.head_type = head_type.lower()
        self.has_sky = (self.head_type == "dpt") and head_use_sky_head
        self.has_conf = head_output_dim > 1
        self.out_layers = list(out_layers)
        backbone_cfg = _build_backbone_config(
            backbone_name,
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )
        self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
        head_kwargs = dict(
            dim_in=head_dim_in,
            patch_size=self.PATCH_SIZE,
            output_dim=head_output_dim,
            features=head_features,
            out_channels=tuple(head_out_channels),
            device=device, dtype=dtype, operations=operations,
        )
        # ``pos_embed`` defaults differ per head; only DPT takes ``use_sky_head``.
        if self.head_type == "dpt":
            head_kwargs.update(
                use_sky_head=head_use_sky_head,
                pos_embed=(False if head_pos_embed is None else head_pos_embed),
            )
        else: # dualdpt
            head_kwargs.update(
                pos_embed=(True if head_pos_embed is None else head_pos_embed),
            )
        self.head = head_cls(**head_kwargs)
        # Built only if checkpoint has weights; cam_enc output dim == embed_dim.
        embed_dim = backbone_cfg["hidden_size"]
        if has_cam_enc:
            self.cam_enc = CameraEnc(
                dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
                num_heads=max(1, embed_dim // 64),
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_enc = None
        if has_cam_dec:
            default_dim = embed_dim * (2 if cat_token else 1)
            self.cam_dec = CameraDec(
                dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
                device=device, dtype=dtype, operations=operations,
            )
        else:
            self.cam_dec = None
        self.dtype = dtype

    def forward(
        self,
        image: torch.Tensor,
        extrinsics: Optional[torch.Tensor] = None,
        intrinsics: Optional[torch.Tensor] = None,
        *,
        use_ray_pose: bool = False,
        ref_view_strategy: str = "saddle_balanced",
        export_feat_layers: Optional[Sequence[int]] = None,
        **_unused,
    ) -> Dict[str, torch.Tensor]:
        """Run depth and optionally pose prediction.

        Args:
            image: ``(B, 3, H, W)`` or ``(B, S, 3, H, W)``; H and W must be
                multiples of ``PATCH_SIZE``.
            extrinsics, intrinsics: when both are given and a camera encoder
                exists, they are encoded into a camera token for the backbone.
            use_ray_pose: enables the DualDPT aux head and pose-from-ray.
            ref_view_strategy: forwarded unchanged to the backbone.
            export_feat_layers: when truthy, ``aux_features`` is added to
                the result.

        Returns:
            Dict of head outputs (per-view maps flattened to ``B * S`` on
            dim 0, ``ray``/``ray_conf`` left as-is) plus ``extrinsics`` and
            ``intrinsics`` when a pose path ran.
        """
        if image.ndim == 4:
            image = image.unsqueeze(1) # (B, 1, 3, H, W)
        assert image.ndim == 5 and image.shape[2] == 3, \
            f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
        B, S, _, H, W = image.shape
        assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
            f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
        # Camera-token preparation (multi-view path).
        cam_token = None
        if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
            cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
        # Toggle aux ray output on/off depending on what the caller asked for.
        # NOTE: this writes to the head module, so the setting persists
        # until the next call overwrites it.
        if isinstance(self.head, DualDPT):
            self.head.enable_aux = bool(use_ray_pose)
        feats, aux_feats = self.backbone.get_intermediate_layers_da3(
            image, self.out_layers, cam_token=cam_token,
            ref_view_strategy=ref_view_strategy,
            export_feat_layers=export_feat_layers,
        )
        head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
        # Pose prediction.
        out: Dict[str, torch.Tensor] = {}
        if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
            ray = head_out["ray"]
            ray_conf = head_out["ray_conf"]
            extr_c2w, focal, pp = get_extrinsic_from_camray(
                ray, ray_conf, ray.shape[-3], ray.shape[-2],
            )
            # Match the upstream output: w2c, drop the homogeneous row.
            extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
            # Build pixel-space intrinsics from the normalised focal/pp output.
            intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
            # ``clone`` materialises the expanded view so the writes below are per-view.
            intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
            intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
            intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
            intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
            intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
            out["extrinsics"] = extr_w2c
            out["intrinsics"] = intr
        elif self.cam_dec is not None and S > 1:
            # Decode the cam-token of the final out_layer into a pose encoding.
            cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
            pose_enc = self.cam_dec(cam_feat)
            c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
            # Match the upstream output convention: w2c (world->camera), 3x4.
            # Append the [0, 0, 0, 1] row so the 4x4 matrix can be inverted.
            c2w_4x4 = torch.cat([
                c2w_3x4,
                torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
                .view(1, 1, 1, 4).expand(B, S, 1, 4),
            ], dim=-2)
            out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
            out["intrinsics"] = intr
        # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
        # per-image consumer keeps its (B*S, H, W) interface.
        for k, v in head_out.items():
            if k in ("ray", "ray_conf"):
                # Keep multi-view shape for downstream pose work.
                out[k] = v
            elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
                out[k] = v.reshape(B * S, *v.shape[2:])
            else:
                out[k] = v
        if export_feat_layers:
            out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
        return out

    def _reshape_aux_features(self, aux_feats, H: int, W: int):
        """Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
        ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
        out = []
        for f in aux_feats:
            B, S, N, C = f.shape
            # The token count must match the patch grid exactly.
            assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
            out.append(f.reshape(B, S, ph, pw, C))
        return out

View File

@ -0,0 +1,128 @@
"""Input/output preprocessing helpers for Depth Anything 3."""
from __future__ import annotations
from typing import Tuple
import torch
import comfy.utils
# Target image sides are rounded to multiples of this value
# (default ``patch`` of ``_round_to_patch``).
PATCH_SIZE = 14
# ImageNet normalization constants used during DA3 training.
# Per-channel values, applied as ``(x - mean) / std`` in ``preprocess_image``.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
    """Round ``x`` to the nearest multiple of ``patch``; ties round up.

    The result can be 0 when ``x`` is less than half a patch.
    """
    quotient, _ = divmod(x, patch)
    lower = quotient * patch
    upper = lower + patch
    if abs(upper - x) <= abs(x - lower):
        return upper
    return lower
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
    """Compute (target_h, target_w) for a single image.

    upper_bound_resize: scale longest side to process_res, then round each dim
        to nearest multiple of 14 (default upstream method).
    lower_bound_resize: scale shortest side to process_res, then round.

    Each returned side is at least one full patch (``PATCH_SIZE``).  With
    extreme aspect ratios the short side can round down to 0; the previous
    floor of 1 pixel produced a size that is not a multiple of the patch
    size.

    Raises:
        ValueError: if ``method`` is not one of the two supported names.
    """
    if method == "upper_bound_resize":
        reference = max(orig_h, orig_w)
    elif method == "lower_bound_resize":
        reference = min(orig_h, orig_w)
    else:
        raise ValueError(f"Unsupported process_res_method: {method}")
    scale = process_res / float(reference)

    def _target(side: int) -> int:
        # Clamp to one patch so the result is always a non-zero patch multiple.
        return max(PATCH_SIZE, _round_to_patch(int(round(side * scale))))

    new_w = _target(orig_w)
    new_h = _target(orig_h)
    return new_h, new_w
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
    """Resize and ImageNet-normalise a channels-last image batch.

    Args:
        image: ``(B, H, W, 3)`` tensor.
        process_res: target resolution passed to ``compute_target_size``.
        method: resize policy passed to ``compute_target_size``.

    Returns:
        ``(B, 3, target_h, target_w)`` tensor with ``(x - mean) / std``
        applied per channel.  The input is converted to float only when a
        resize actually happens.
    """
    assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
    B, H, W, _ = image.shape
    target_h, target_w = compute_target_size(H, W, process_res, method)
    # (B, H, W, 3) -> (B, 3, H, W)
    x = image.movedim(-1, 1).contiguous()
    if (target_h, target_w) != (H, W):
        # Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
        # Lanczos in ``common_upscale`` is anti-aliased and produces the
        # closest pixel-wise match in a sweep across {bilinear, bicubic,
        # area, lanczos, bislerp}. Used in both directions for simplicity.
        x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
        # NOTE(review): clamp is grouped with the resize here (presumably to
        # undo resampling overshoot outside [0, 1]) -- confirm its scope.
        x = x.clamp(0.0, 1.0)
    mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
    x = (x - mean) / std
    return x
# -----------------------------------------------------------------------------
# Output post-processing (sky-aware clipping for Mono/Metric variants)
# -----------------------------------------------------------------------------
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """Return a boolean mask that is True where the sky score is below ``threshold``.

    True marks non-sky pixels; values equal to ``threshold`` count as sky.
    """
    is_not_sky = torch.lt(sky_prediction, threshold)
    return is_not_sky
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
    """Overwrite sky pixels with a high quantile of the non-sky depth.

    Sky pixels are those whose ``sky`` value is not below ``threshold``.
    They are set to the ``quantile`` (default 0.99) of the non-sky depth,
    estimated from at most 100,000 randomly drawn samples.  If either region
    has 10 or fewer pixels, an unmodified copy is returned.  The input is
    never changed in place.
    """
    keep = compute_non_sky_mask(sky, threshold=threshold)
    sky_region = ~keep
    if keep.sum() <= 10 or sky_region.sum() <= 10:
        return depth.clone()

    candidates = depth[keep]
    budget = 100_000
    if candidates.numel() > budget:
        # Subsample with replacement to bound the cost of the quantile.
        picks = torch.randint(0, candidates.numel(), (budget,), device=candidates.device)
        candidates = candidates[picks]
    ceiling = torch.quantile(candidates, quantile)

    clipped = depth.clone()
    clipped[sky_region] = ceiling
    return clipped
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
if sky is not None:
mask = compute_non_sky_mask(sky)
if mask.any():
valid = depth[mask]
else:
valid = depth.flatten()
else:
valid = depth.flatten()
if valid.numel() > 100_000:
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
sample = valid[idx]
else:
sample = valid
lo = torch.quantile(sample, low_quantile)
hi = torch.quantile(sample, high_quantile)
rng = (hi - lo).clamp(min=1e-6)
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
# Nearer pixels are brighter (1.0)
norm = 1.0 - norm
if sky is not None:
# Sky pixels become black (far / unknown)
sky_mask = ~compute_non_sky_mask(sky)
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
return norm
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
    """Min/max-normalize over the last two axes and invert, so the smallest depth maps to 1.0.

    The span is floored at 1e-6, which makes a constant map come out as all ones.
    """
    spatial_axes = (-2, -1)
    nearest = depth.amin(dim=spatial_axes, keepdim=True)
    farthest = depth.amax(dim=spatial_axes, keepdim=True)
    span = (farthest - nearest).clamp(min=1e-6)
    far_fraction = ((depth - nearest) / span).clamp(0.0, 1.0)
    return 1.0 - far_fraction

View File

@ -0,0 +1,272 @@
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
from __future__ import annotations
from typing import Optional, Tuple
import torch
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Factor a (3, 3) matrix as ``A = Q @ L`` (Q orthogonal, L lower-triangular).

    A QR factorization of the column-reversed matrix is computed and the
    reversal is undone afterwards. Signs are then fixed so that the diagonal
    of ``L`` is non-negative.
    """
    # Anti-diagonal permutation: reverses columns when applied on the right
    # and rows when applied on the left.
    anti_eye = torch.eye(3, device=A.device, dtype=A.dtype).flip(0)
    reversed_cols = A @ anti_eye
    # CUDA QR is not implemented for fp16/bf16; run it in fp32 and cast back.
    q_rev, r_rev = torch.linalg.qr(reversed_cols.float())
    q_rev, r_rev = q_rev.to(A.dtype), r_rev.to(A.dtype)
    Q = q_rev @ anti_eye
    L = anti_eye @ r_rev @ anti_eye
    diag_sign = torch.sign(torch.diag(L))
    # Flip column j of Q together with row j of L; the product is unchanged.
    Q = Q * diag_sign.unsqueeze(0)
    L = L * diag_sign.unsqueeze(1)
    return Q, L
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
    """Append a trailing coordinate equal to 1 along the last axis."""
    unit_column = points.new_ones(points.shape[:-1] + (1,))
    return torch.cat((points, unit_column), dim=-1)
# -----------------------------------------------------------------------------
# Weighted-LSQ + RANSAC homography (batched)
# -----------------------------------------------------------------------------
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
    """Estimate one homography from weighted point pairs via DLT.

    Args:
        src_pts: (N, 2) source points.
        dst_pts: (N, 2) destination points.
        confident_weight: (N,) weights; their square roots scale the
            two equations contributed by each pair.

    Returns:
        (3, 3) homography scaled so that its bottom-right entry is 1.

    Raises:
        ValueError: If fewer than 4 point pairs are supplied.
    """
    if src_pts.shape[0] < 4:
        raise ValueError("At least 4 points are required to compute a homography.")
    w = confident_weight.sqrt().unsqueeze(1)  # (N, 1)
    x, y = src_pts[:, 0:1], src_pts[:, 1:2]
    u, v = dst_pts[:, 0:1], dst_pts[:, 1:2]
    zeros = torch.zeros_like(x)
    neg_src = [-x * w, -y * w, -w]
    blank = [zeros, zeros, zeros]
    rows_u = torch.cat(neg_src + blank + [x * u * w, y * u * w, u * w], dim=1)
    rows_v = torch.cat(blank + neg_src + [x * v * w, y * v * w, v * w], dim=1)
    system = torch.cat([rows_u, rows_v], dim=0)  # (2N, 9)
    # CUDA SVD is not implemented for fp16/bf16; run it in fp32 and cast back.
    Vh = torch.linalg.svd(system.float())[2].to(system.dtype)
    # The last right-singular vector is the least-squares null vector.
    H = Vh[-1].reshape(3, 3)
    return H / H[-1, -1]
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
    """Weighted DLT homography for a batch of point sets.

    Args:
        src_pts_batch: (B, K, 2) source points.
        dst_pts_batch: (B, K, 2) destination points.
        confident_weight_batch: (B, K) weights (square-rooted before use).

    Returns:
        (B, 3, 3) homographies, each scaled so that its [2, 2] entry is 1.
    """
    batch = src_pts_batch.shape[0]
    w = confident_weight_batch.sqrt().unsqueeze(2)
    x, y = src_pts_batch[:, :, 0:1], src_pts_batch[:, :, 1:2]
    u, v = dst_pts_batch[:, :, 0:1], dst_pts_batch[:, :, 1:2]
    zeros = torch.zeros_like(x)
    neg_src = [-x * w, -y * w, -w]
    blank = [zeros, zeros, zeros]
    rows_u = torch.cat(neg_src + blank + [x * u * w, y * u * w, u * w], dim=2)
    rows_v = torch.cat(blank + neg_src + [x * v * w, y * v * w, v * w], dim=2)
    system = torch.cat([rows_u, rows_v], dim=1)  # (B, 2K, 9)
    # CUDA SVD is not implemented for fp16/bf16; run it in fp32 and cast back.
    Vh = torch.linalg.svd(system.float())[2].to(system.dtype)
    # Last right-singular vector of every system, reshaped to 3x3.
    H = Vh[:, -1].reshape(batch, 3, 3)
    return H / H[:, 2:3, 2:3]
def _ransac_find_homography_weighted_batched(
    src_pts: torch.Tensor,  # (B, N, 2)
    dst_pts: torch.Tensor,  # (B, N, 2)
    confident_weight: torch.Tensor,  # (B, N)
    n_sample: int,
    n_iter: int = 100,
    reproj_threshold: float = 3.0,
    num_sample_for_ransac: int = 8,
    max_inlier_num: int = 10000,
    rand_sample_iters_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices.

    Args:
        src_pts: Source points.
        dst_pts: Destination points.
        confident_weight: Per-point weights; used to pick candidates, to
            score hypotheses and to weight the DLT fits.
        n_sample: Number of highest-weight points per batch element that
            hypotheses may draw from.
        n_iter: Number of hypotheses per batch element.
        reproj_threshold: Points with reprojection error below this value
            count as inliers.
        num_sample_for_ransac: Points used to fit each hypothesis.
        max_inlier_num: Cap on the number of inliers used in the final refit.
        rand_sample_iters_idx: Optional pre-drawn indices into the candidate
            set. NOTE(review): ``n_iter`` and ``num_sample_for_ransac`` are
            used to expand ``b_idx`` below, so they presumably must match the
            shape of this tensor; confirm with callers.

    Returns:
        One refitted homography per batch element; the identity is returned
        for elements whose best hypothesis has fewer than 4 inliers.
    """
    B, N, _ = src_pts.shape
    assert N >= 4
    device = src_pts.device
    # Restrict hypothesis sampling to the n_sample highest-weight points.
    sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
    candidate_idx = sorted_idx[:, :n_sample]  # (B, n_sample)
    if rand_sample_iters_idx is None:
        # One random subset per iteration, shared by all batch elements.
        rand_sample_iters_idx = torch.stack(
            [torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
             for _ in range(n_iter)],
            dim=0,
        )
    rand_idx = candidate_idx[:, rand_sample_iters_idx]  # (B, n_iter, k)
    b_idx = (
        torch.arange(B, device=device)
        .view(B, 1, 1)
        .expand(B, n_iter, num_sample_for_ransac)
    )
    # Gather the sampled points / weights for every (batch, iteration) pair.
    src_b = src_pts[b_idx, rand_idx]
    dst_b = dst_pts[b_idx, rand_idx]
    w_b = confident_weight[b_idx, rand_idx]
    cB, cN = src_b.shape[:2]
    # Fit all hypotheses in a single batched DLT call.
    H_batch = _find_homography_weighted_lsq_batched(
        src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
    ).unflatten(0, (cB, cN))  # (B, n_iter, 3, 3)
    # Project every source point through every hypothesis.
    src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
    proj = torch.bmm(
        src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
        H_batch.reshape(-1, 3, 3).transpose(1, 2),
    )  # (B*n_iter, N, 3)
    proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
    err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt()  # (B, n_iter, N)
    inlier_mask = err < reproj_threshold
    # A hypothesis is scored by the summed weight of its inliers.
    score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
    best_idx = torch.argmax(score, dim=1)
    best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
    # Refit with the inlier set (per-batch, since the inlier counts vary).
    H_inlier_list = []
    for b in range(B):
        mask = best_inlier_mask[b]
        in_src = src_pts[b][mask]
        in_dst = dst_pts[b][mask]
        in_w = confident_weight[b][mask]
        if in_src.shape[0] < 4:
            # Fall back to identity when RANSAC fails to find enough inliers.
            H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
            continue
        sorted_w = torch.argsort(in_w, descending=True)
        if len(sorted_w) > max_inlier_num:
            # Keep the highest-weight inliers (at least max_inlier_num, at
            # most 95% of them), then draw max_inlier_num of those at random.
            keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
            sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
        H_inlier_list.append(
            _find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
        )
    return torch.stack(H_inlier_list, dim=0)
# -----------------------------------------------------------------------------
# Camera-ray utilities
# -----------------------------------------------------------------------------
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
    """Build a (B, S, num_y, num_x, 3) grid of ``(x, y, 1)`` rays.

    x and y are evenly spaced cell centres in ``[-(1 - 1/num), 1 - 1/num]``,
    i.e. the centres of a ``num_y`` x ``num_x`` tiling of the square
    ``[-1, 1]^2``; the same grid is repeated for every batch element and view.
    """
    half_cell_x = 1.0 / num_x
    half_cell_y = 1.0 / num_y
    # Centered coordinates are generated directly, which skips an explicit
    # K^-1 multiplication (a plain shift by -1 for the identity-with-center=1 K).
    ys = torch.linspace(-(1 - half_cell_y), (1 - half_cell_y), num_y, device=device, dtype=dtype)
    xs = torch.linspace(-(1 - half_cell_x), (1 - half_cell_x), num_x, device=device, dtype=dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing="ij")
    xy = torch.stack((grid_x, grid_y), dim=-1)  # (h, w, 2)
    xy = xy[None, None].expand(B, S, num_y, num_x, 2)
    unit_z = torch.ones_like(xy[..., :1])
    return torch.cat([xy, unit_z], dim=-1)
def _camray_to_caminfo(
    camray: torch.Tensor,  # (B, S, h, w, 6)
    confidence: Optional[torch.Tensor] = None,  # (B, S, h, w)
    reproj_threshold: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert per-pixel camera rays to per-view (R, T, focal, principal).

    Channels ``[..., :3]`` of ``camray`` are matched against the identity-camera
    rays with a weighted RANSAC homography; the homography is then split by a
    QL decomposition into a rotation and a lower-triangular factor from which
    focal and principal-point terms are read. Channels ``[..., 3:]`` are
    averaged with the confidence weights to give the translation.

    Args:
        camray: Per-pixel ray tensor.
        confidence: Optional per-pixel weights; all ones when omitted.
        reproj_threshold: Inlier threshold forwarded to the RANSAC routine.

    Returns:
        ``(R, T, 1 / focal, pp + 1)`` with shapes (B, S, 3, 3), (B, S, 3),
        (B, S, 2) and (B, S, 2).
    """
    if confidence is None:
        confidence = torch.ones_like(camray[..., 0])
    B, S, h, w, _ = camray.shape
    device = camray.device
    dtype = camray.dtype
    rays_target = camray[..., :3]  # (B, S, h, w, 3)
    rays_origin = _unproject_identity(h, w, B, S, device, dtype)
    # Flatten (B*S, h*w, *) for the RANSAC routine.
    rays_target = rays_target.flatten(0, 1).flatten(1, 2)
    rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
    weights = confidence.flatten(0, 1).flatten(1, 2).clone()
    # Project to 2D in homogeneous form (the upstream calls this "perspective division").
    z_thresh = 1e-4
    # Rays with a near-zero z cannot be divided safely: they keep their raw
    # x/y and get zero weight.
    mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
    weights = torch.where(mask, weights, torch.zeros_like(weights))
    src = rays_origin.clone()
    dst = rays_target.clone()
    src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
    src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
    dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
    dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
    src = src[..., :2]
    dst = dst[..., :2]
    N = src.shape[1]
    n_iter = 100
    sample_ratio = 0.3
    num_sample_for_ransac = 8
    # Hypotheses sample from the top 30% of points by weight (at least 8).
    n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
    # Draw the sample indices once so every chunk reuses the same subsets.
    rand_idx = torch.stack(
        [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
        dim=0,
    )
    # Chunk along the view axis to keep peak memory predictable.
    chunk = 2
    A_list = []
    for i in range(0, src.shape[0], chunk):
        A = _ransac_find_homography_weighted_batched(
            src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
            n_sample=n_sample, n_iter=n_iter,
            num_sample_for_ransac=num_sample_for_ransac,
            reproj_threshold=reproj_threshold,
            rand_sample_iters_idx=rand_idx,
            max_inlier_num=8000,
        )
        # Flip sign on dets that come out < 0 (so that the QL produces a
        # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
        # do the comparison in fp32.
        flip = torch.linalg.det(A.float()) < 0
        A = torch.where(flip[:, None, None], -A, A)
        A_list.append(A)
    A = torch.cat(A_list, dim=0)  # (B*S, 3, 3)
    R_list, f_list, pp_list = [], [], []
    for i in range(A.shape[0]):
        R, L = _ql_decomposition(A[i])
        # Scale L so that its bottom-right entry is 1 before reading values.
        L = L / L[2][2]
        f_list.append(torch.stack((L[0][0], L[1][1])))
        pp_list.append(torch.stack((L[2][0], L[2][1])))
        R_list.append(R)
    R = torch.stack(R_list).reshape(B, S, 3, 3)
    focal = torch.stack(f_list).reshape(B, S, 2)
    pp = torch.stack(pp_list).reshape(B, S, 2)
    # Translation: confidence-weighted average of camray direction(s).
    # Uses the caller's confidence, not the z-masked ``weights`` above.
    cf = confidence.flatten(0, 1).flatten(1, 2)
    T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
    T = T / cf.sum(dim=-1, keepdim=True)
    T = T.reshape(B, S, 3)
    # Match upstream output convention: focal -> 1/focal, pp + 1.
    return R, T, 1.0 / focal, pp + 1.0
def get_extrinsic_from_camray(
    camray: torch.Tensor,  # (B, S, h, w, 6)
    conf: torch.Tensor,  # (B, S, h, w, 1) or (B, S, h, w)
    patch_size_y: int,
    patch_size_x: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Estimate per-view cameras from rays and return 4x4 extrinsics plus focal and principal point.

    ``patch_size_y`` / ``patch_size_x`` are accepted but not read by this function.

    Returns:
        ``(extr, focal, pp)`` where ``extr`` is (B, S, 4, 4) built from
        ``[R | T]`` with a ``[0, 0, 0, 1]`` bottom row.
    """
    has_trailing_unit_axis = conf.ndim == 5 and conf.shape[-1] == 1
    if has_trailing_unit_axis:
        conf = conf.squeeze(-1)
    R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
    n_batch, n_view = R.shape[0], R.shape[1]
    top_rows = torch.cat([R, T.unsqueeze(-1)], dim=-1)  # (B, S, 3, 4)
    bottom_row = R.new_tensor([0, 0, 0, 1]).view(1, 1, 1, 4).expand(n_batch, n_view, 1, 4)
    extr = torch.cat([top_rows, bottom_row], dim=-2)  # (B, S, 4, 4)
    return extr, focal, pp

View File

@ -0,0 +1,87 @@
"""Reference-view selection for the multi-view path of Depth Anything 3."""
from __future__ import annotations
from typing import Literal
import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
# Reference selection only runs when there are at least this many views.
THRESH_FOR_REF_SELECTION: int = 3
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
    """Choose one reference view index for every batch element.

    Args:
        x: 4-D tensor whose first two axes are (batch, views); token 0 of
            each view is the feature used by the ``saddle_*`` strategies.
        strategy: ``"first"``, ``"middle"``, ``"saddle_balanced"`` or
            ``"saddle_sim_range"``.

    Returns:
        (B,) long tensor of view indices.

    Raises:
        ValueError: If ``strategy`` is not one of the names above (only
            checked when there is more than one view).
    """
    B, S, _, _ = x.shape
    if S <= 1 or strategy == "first":
        return torch.zeros(B, dtype=torch.long, device=x.device)
    if strategy == "middle":
        return torch.full((B,), S // 2, dtype=torch.long, device=x.device)

    # Feature-based strategies: unit-normalised token 0 of every view.
    view_token = x[:, :, 0]
    unit_feat = view_token / view_token.norm(dim=-1, keepdim=True)  # (B,S,C)

    def _pairwise_similarity_without_self():
        # Cosine similarity between views with 1 subtracted on the diagonal.
        sim = torch.matmul(unit_feat, unit_feat.transpose(1, 2))  # (B,S,S)
        return sim - torch.eye(S, device=sim.device).unsqueeze(0)

    def _rescale_per_batch(metric):
        # Min/max rescale across views to [0, 1] (epsilon-guarded).
        lowest = metric.min(dim=1, keepdim=True).values
        highest = metric.max(dim=1, keepdim=True).values
        return (metric - lowest) / (highest - lowest + 1e-8)

    if strategy == "saddle_balanced":
        mean_sim = _pairwise_similarity_without_self().sum(dim=-1) / (S - 1)  # (B,S)
        token_norm = x[:, :, 0].norm(dim=-1)  # (B,S)
        token_var = unit_feat.var(dim=-1)  # (B,S)
        # Pick the view whose three rescaled metrics sit closest to 0.5.
        distance_from_mid = (
            (_rescale_per_batch(mean_sim) - 0.5).abs()
            + (_rescale_per_batch(token_norm) - 0.5).abs()
            + (_rescale_per_batch(token_var) - 0.5).abs()
        )
        return distance_from_mid.argmin(dim=1)
    if strategy == "saddle_sim_range":
        sim_no_diag = _pairwise_similarity_without_self()
        spread = sim_no_diag.max(dim=-1).values - sim_no_diag.min(dim=-1).values
        return spread.argmax(dim=1)
    raise ValueError(
        f"Unknown reference view selection strategy: {strategy!r}. "
        f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
    )
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
    """Move each batch element's reference view to slot 0 of axis 1.

    Views that preceded the reference shift one slot to the right; views
    after it stay in place. ``b_idx`` holds one reference index per batch
    element. With a single view ``x`` is returned as is.
    """
    B, S = x.shape[0], x.shape[1]
    if S <= 1:
        return x
    slots = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
    ref = b_idx.unsqueeze(1)
    # Slots 1..ref read from the previous source view; the rest read their own.
    shifted_window = (slots >= 1) & (slots <= ref)
    gather_idx = torch.where(shifted_window, slots - 1, slots)
    gather_idx[:, 0] = b_idx
    rows = torch.arange(B, device=x.device).unsqueeze(1)
    return x[rows, gather_idx]
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
    """Undo ``reorder_by_reference``: put the view held in slot 0 back at index ``b_idx``.

    With a single view ``x`` is returned as is.
    """
    B, S = x.shape[0], x.shape[1]
    if S <= 1:
        return x
    slots = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
    ref = b_idx.unsqueeze(1)
    # Original views before the reference were shifted right by one slot.
    gather_idx = torch.where(slots < ref, slots + 1, slots)
    # The reference view itself currently lives in slot 0.
    gather_idx = gather_idx.scatter(1, ref, torch.zeros_like(ref))
    rows = torch.arange(B, device=x.device).unsqueeze(1)
    return x[rows, gather_idx]

View File

@ -0,0 +1,160 @@
"""Geometry / camera transform helpers for Depth Anything 3."""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Affine 4x4 helpers
# -----------------------------------------------------------------------------
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
    """Return ``ext`` as (...,4,4): a (...,3,4) input gets a ``[0, 0, 0, 1]`` row appended.

    A (...,4,4) input is returned unchanged (same object).

    Raises:
        ValueError: If the trailing shape is neither (3, 4) nor (4, 4).
    """
    trailing = ext.shape[-2:]
    if trailing == (4, 4):
        return ext
    if trailing != (3, 4):
        raise ValueError(f"Invalid affine shape: {ext.shape}")
    bottom_row = ext.new_zeros(ext.shape[:-2] + (1, 4))
    bottom_row[..., 0, 3] = 1.0
    return torch.cat([ext, bottom_row], dim=-2)
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
    """Invert a rigid transform ``[R|T; 0 0 0 1]`` as ``[R^T | -R^T T]``.

    The bottom row(s) of ``A`` (index 3 onward) are copied through unchanged.
    """
    rot_t = A[..., :3, :3].transpose(-2, -1)
    trans = A[..., :3, 3:]
    bottom = A[..., 3:, :]
    top = torch.cat([rot_t, -rot_t @ trans], dim=-1)
    return torch.cat([top, bottom], dim=-2)
# -----------------------------------------------------------------------------
# Quaternion <-> rotation matrix (xyzw / scalar-last)
# -----------------------------------------------------------------------------
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """Element-wise ``sqrt(max(0, x))``; entries with ``x <= 0`` come out as 0.

    With autograd enabled the root is only evaluated on the strictly
    positive entries, which gives a zero subgradient where x == 0.
    """
    is_positive = x > 0
    zeros = torch.zeros_like(x)
    if not torch.is_grad_enabled():
        return torch.where(is_positive, torch.sqrt(x), zeros)
    zeros[is_positive] = torch.sqrt(x[is_positive])
    return zeros
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """Negate every xyzw quaternion whose scalar (last) component is negative."""
    scalar_part = quaternions[..., 3:4]
    needs_flip = scalar_part < 0
    return torch.where(needs_flip, -quaternions, quaternions)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """Turn xyzw quaternions of shape (...,4) into (...,3,3) rotation matrices.

    The ``2 / |q|^2`` factor means the input does not have to be unit length
    (a zero quaternion still divides by zero).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)
    row0 = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
    )
    row1 = (
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
    )
    row2 = (
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(row0 + row1 + row2, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """Convert (...,3,3) rotation matrices to quaternions (xyzw).

    Four candidate quaternions are formed, one per component assumed to be
    the largest, and the candidate belonging to the largest ``q_abs`` entry
    is selected. The result is passed through ``standardize_quaternion``.

    Raises:
        ValueError: If the trailing shape of ``matrix`` is not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )
    # Square roots of the four trace-like sums, ordered (r, i, j, k).
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )
    # Row n is the (r, i, j, k) candidate built around q_abs[..., n], still
    # unscaled.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )
    # Floor the divisor at 0.1 so candidates with a tiny q_abs do not blow up;
    # those candidates are not the ones selected below.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
    # Select, per matrix, the candidate row with the largest q_abs.
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
        batch_dim + (4,)
    )
    # Reorder rijk -> xyzw (i.e. ijkr).
    out = out[..., [1, 2, 3, 0]]
    return standardize_quaternion(out)
# -----------------------------------------------------------------------------
# Pose-encoding <-> extrinsics + intrinsics
# -----------------------------------------------------------------------------
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
    """Encode cameras as a 9-D vector ``[T(3), quat_xyzw(4), fov_h, fov_w]``.

    Args:
        extrinsics: camera-to-world (c2w) (B,S,4,4) matrices.
        intrinsics: pixel-space (B,S,3,3) matrices.
        image_size_hw: ``(H, W)`` pair used to turn focal lengths into
            fields of view.

    Returns:
        float32 tensor with a trailing axis of size 9.
    """
    height, width = image_size_hw
    translation = extrinsics[..., :3, 3]
    rotation = extrinsics[..., :3, :3]
    quat = mat_to_quat(rotation)
    focal_y = intrinsics[..., 1, 1]
    focal_x = intrinsics[..., 0, 0]
    fov_h = 2 * torch.atan((height / 2) / focal_y)
    fov_w = 2 * torch.atan((width / 2) / focal_x)
    pieces = [translation, quat, fov_h.unsqueeze(-1), fov_w.unsqueeze(-1)]
    return torch.cat(pieces, dim=-1).float()
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decode a 9-D pose encoding ``[T(3), quat_xyzw(4), fov_h, fov_w]`` into camera matrices.

    Args:
        pose_encoding: (..., 9) tensor, typically (B, S, 9).
        image_size_hw: ``(H, W)`` pair used to turn fields of view into
            pixel focal lengths and to place the principal point at the
            image centre.

    Returns:
        ``(extrinsics, intrinsics)`` with shapes (..., 3, 4) and (..., 3, 3).
        Note the extrinsics are ``[R | T]`` without the homogeneous row.
    """
    T = pose_encoding[..., :3]
    quat = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]
    # Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
    # quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
    quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
    R = quat_to_mat(quat)
    extrinsics = torch.cat([R, T[..., None]], dim=-1)
    H, W = image_size_hw
    # Floor tan(fov / 2) so a zero or negative value cannot produce an
    # infinite or negative focal length.
    fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
    fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
    # Size the buffer from all leading dims so inputs of any rank work
    # (identical to the former ``shape[:2]`` for (B, S, 9) inputs).
    intrinsics = torch.zeros(pose_encoding.shape[:-1] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
    intrinsics[..., 0, 0] = fx
    intrinsics[..., 1, 1] = fy
    intrinsics[..., 0, 2] = W / 2
    intrinsics[..., 1, 2] = H / 2
    intrinsics[..., 2, 2] = 1.0
    return extrinsics, intrinsics

View File

@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
def forward(self, x, dtype):
x = x.to(torch.float32)
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
emb = _sinusoidal_embedding(scaled, self.dim)
emb = emb.to(self.mlp_in.weight.dtype)
emb = emb.to(dtype)
emb = F.silu(self.mlp_in(emb))
return self.mlp_out(emb)
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
x = x * output_image_mask
h = self.input_proj(x) * output_image_mask
t_cond = self.t_embedding(t)
t_cond = self.t_embedding(t, dtype=x.dtype)
if t.dim() == 1:
t_cond = t_cond.unsqueeze(1)
adaln_input = F.silu(self.adaln_proj(t_cond))
@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module):
llm = self.llm_cond_proj(llm) * text_mask
h[:, :L_text] = h[:, :L_text] + llm
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long))
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
freqs_cis = precompute_freqs_cis(
@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
L_text = context_chunk.shape[1]
L = L_text + L_img
@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)

290
comfy/ldm/krea2/model.py Normal file
View File

@ -0,0 +1,290 @@
"""Krea 2 (K2) — single-stream MMDiT.
Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are
concatenated into one sequence and run through ``layers`` shared transformer blocks with
AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE.
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import comfy.model_management
import comfy.patcher_extension
import comfy.ldm.common_dit
from comfy.ldm.flux.layers import EmbedND, timestep_embedding
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention_masked
class RMSNorm(nn.Module):
    """RMS normalization whose stored ``scale`` is zero-centered: the applied weight is ``1 + scale``."""

    def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None):
        super().__init__()
        self.eps = eps
        # ``operations`` is accepted but not used by this module.
        self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the last axis in fp32 and return the input dtype."""
        in_dtype = x.dtype
        scale_fp32 = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device)
        normed = F.rms_norm(x.float(), (x.shape[-1],), weight=scale_fp32 + 1.0, eps=self.eps)
        return normed.to(in_dtype)
class QKNorm(nn.Module):
    """Pair of independent ``RMSNorm`` layers applied to queries and keys."""

    def __init__(self, dim: int, device=None, dtype=None, operations=None):
        super().__init__()
        norm_kwargs = dict(device=device, dtype=dtype, operations=operations)
        self.qnorm = RMSNorm(dim, **norm_kwargs)
        self.knorm = RMSNorm(dim, **norm_kwargs)

    def forward(self, q, k):
        """Return ``(qnorm(q), knorm(k))``."""
        q_normed = self.qnorm(q)
        k_normed = self.knorm(k)
        return q_normed, k_normed
class SwiGLU(nn.Module):
    """Gated MLP: ``down(silu(gate(x)) * up(x))``.

    The hidden width is ``int(2 * features / 3) * multiplier`` rounded up to
    the next multiple of ``multiple``.
    """

    def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128,
                 device=None, dtype=None, operations=None):
        super().__init__()
        raw_width = int(2 * features / 3) * multiplier
        # Ceil-divide, then scale back up to land on a multiple of ``multiple``.
        hidden = multiple * ((raw_width + multiple - 1) // multiple)
        linear_kwargs = dict(bias=bias, device=device, dtype=dtype)
        self.gate = operations.Linear(features, hidden, **linear_kwargs)
        self.up = operations.Linear(features, hidden, **linear_kwargs)
        self.down = operations.Linear(hidden, features, **linear_kwargs)

    def forward(self, x):
        """Apply the gated MLP; the product is formed in place on the activation."""
        activated = F.silu(self.gate(x))
        activated.mul_(self.up(x))
        return self.down(activated)
class Attention(nn.Module):
    """Self-attention with optional grouped KV heads, per-head QK norm, optional RoPE and a sigmoid output gate."""

    def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False,
                 device=None, dtype=None, operations=None):
        super().__init__()
        self.heads = heads
        self.kvheads = heads if kvheads is None else kvheads
        self.headdim = dim // self.heads
        q_width = self.headdim * self.heads
        kv_width = self.headdim * self.kvheads
        linear_kwargs = dict(bias=bias, device=device, dtype=dtype)
        self.wq = operations.Linear(dim, q_width, **linear_kwargs)
        self.wk = operations.Linear(dim, kv_width, **linear_kwargs)
        self.wv = operations.Linear(dim, kv_width, **linear_kwargs)
        self.gate = operations.Linear(dim, dim, **linear_kwargs)
        self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations)
        self.wo = operations.Linear(dim, dim, **linear_kwargs)

    def forward(self, x, freqs=None, mask=None, transformer_options={}):
        """Attend over axis 1 of ``x`` and return the gated, projected output.

        ``freqs`` enables rotary embedding on q/k when given; ``mask`` and
        ``transformer_options`` are forwarded to the attention kernel.
        """
        q = rearrange(self.wq(x), "B L (H D) -> B H L D", H=self.heads)
        k = rearrange(self.wk(x), "B L (H D) -> B H L D", H=self.kvheads)
        v = rearrange(self.wv(x), "B L (H D) -> B H L D", H=self.kvheads)
        gate = self.gate(x)
        q, k = self.qknorm(q, k)
        if freqs is not None:
            q, k = apply_rope(q, k, freqs)
        if self.kvheads != self.heads:
            # Grouped-query attention: repeat each KV head for its query group.
            group_size = self.heads // self.kvheads
            k = k.repeat_interleave(group_size, dim=1)
            v = v.repeat_interleave(group_size, dim=1)
        attended = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True,
                                              transformer_options=transformer_options)
        return self.wo(attended * F.sigmoid(gate))
class SimpleModulation(nn.Module):
    """Adds a learned (2, dim) table to the conditioning vector and splits it into scale and shift."""

    def __init__(self, dim: int, device=None, dtype=None, operations=None):
        super().__init__()
        # ``operations`` is accepted but not used by this module.
        self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))

    def forward(self, vec):
        """Return ``(scale, shift)``, the two halves of ``vec + lin`` along axis 1."""
        table = comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
        modulated = vec + table.unsqueeze(0)
        scale, shift = modulated.chunk(2, dim=1)
        return scale, shift
class DoubleSharedModulation(nn.Module):
    """Adds a learned ``6 * dim`` offset to the conditioning vector and splits it into six chunks."""

    def __init__(self, dim: int, device=None, dtype=None, operations=None):
        super().__init__()
        # ``operations`` is accepted but not used by this module.
        self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype))

    def forward(self, vec):
        """Return the six equal chunks of ``vec + lin`` along the last axis."""
        offset = comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
        modulated = vec + offset
        return modulated.chunk(6, dim=-1)
class TextFusionBlock(nn.Module):
    """Pre-norm transformer block without modulation: attention then SwiGLU, each with a residual."""

    def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
        super().__init__()
        factory = dict(device=device, dtype=dtype, operations=operations)
        self.prenorm = RMSNorm(features, **factory)
        self.postnorm = RMSNorm(features, **factory)
        self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, **factory)
        self.mlp = SwiGLU(features, multiplier, bias, **factory)

    def forward(self, x, mask=None, transformer_options={}):
        """Run attention (no rotary frequencies) and the MLP with residual connections."""
        attn_out = self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options)
        x = x + attn_out
        mlp_out = self.mlp(self.postnorm(x))
        return x + mlp_out
class TextFusionTransformer(nn.Module):
    """Fuses a stack of per-layer text features into one feature per token.

    Two blocks attend across the ``n`` axis of a ``(b, l, n, d)`` input, a
    bias-free linear layer collapses that axis to size 1 (so ``n`` must equal
    ``num_txt_layers``), and two refiner blocks then attend across ``l``.
    """

    def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
        super().__init__()

        def _new_block():
            return TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads,
                                   device=device, dtype=dtype, operations=operations)

        self.layerwise_blocks = nn.ModuleList([_new_block() for _ in range(2)])
        self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype)
        self.refiner_blocks = nn.ModuleList([_new_block() for _ in range(2)])

    def forward(self, x, mask=None, transformer_options={}):
        """Map ``(b, l, n, d)`` to ``(b, l, d)``; ``mask`` is only applied in the refiner blocks."""
        b, l, n, d = x.shape
        # Fold (b, l) together so the first stage attends over the n axis.
        x = x.reshape(b * l, n, d)
        for layer in self.layerwise_blocks:
            x = layer(x.contiguous(), mask=None, transformer_options=transformer_options)
        # Bring n to the last axis so the projector can reduce it to size 1.
        x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l)
        x = self.projector(x).squeeze(-1)
        for layer in self.refiner_blocks:
            x = layer(x, mask=mask, transformer_options=transformer_options)
        return x
class SingleStreamBlock(nn.Module):
    """Transformer block whose attention and MLP branches are modulated (scale, shift, gate) by a conditioning vector."""

    def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
        super().__init__()
        factory = dict(device=device, dtype=dtype, operations=operations)
        self.mod = DoubleSharedModulation(features, **factory)
        self.prenorm = RMSNorm(features, **factory)
        self.postnorm = RMSNorm(features, **factory)
        self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, **factory)
        self.mlp = SwiGLU(features, multiplier, bias, **factory)

    def forward(self, x, vec, freqs, mask=None, transformer_options={}):
        """Apply gated, modulated attention followed by a gated, modulated MLP, both residual."""
        attn_scale, attn_shift, attn_gate, mlp_scale, mlp_shift, mlp_gate = self.mod(vec)
        attn_in = (1 + attn_scale) * self.prenorm(x) + attn_shift
        x = x + attn_gate * self.attn(attn_in, freqs, mask, transformer_options=transformer_options)
        mlp_in = (1 + mlp_scale) * self.postnorm(x) + mlp_shift
        x = x + mlp_gate * self.mlp(mlp_in)
        return x
class LastLayer(nn.Module):
    """Output head: modulated RMSNorm followed by a linear projection to ``patch * patch * channels`` values per token."""

    def __init__(self, features, patch, channels, device=None, dtype=None, operations=None):
        super().__init__()
        out_width = patch * patch * channels
        self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
        self.linear = operations.Linear(features, out_width, bias=True, device=device, dtype=dtype)
        self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations)

    def forward(self, x, tvec):
        """Scale/shift the normalized tokens using ``tvec`` and project them."""
        scale, shift = self.modulation(tvec)
        modulated = (1 + scale) * self.norm(x) + shift
        return self.linear(modulated)
class SingleStreamDiT(nn.Module):
    """Single-stream diffusion transformer over concatenated text and image tokens.

    The latent is patchified and projected by ``first``; the text conditioning is
    passed through ``txtfusion`` and ``txtmlp``. Both token sequences are
    concatenated (text first), run through ``blocks`` with rotary position
    embeddings, and the image tokens are projected back to latent patches by
    ``last``.
    """

    def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4,
                 layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12,
                 txtheads=20, txtkvheads=20, image_model=None,
                 device=None, dtype=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
        self.patch = patch
        self.channels = channels
        self.tdim = tdim
        self.heads = heads
        self.txtdim = txtdim
        self.txtlayers = txtlayers

        headdim = features // heads
        # Per-head rotary dims split over the three position axes used in _forward.
        axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)]
        assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}"
        self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes)

        self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype)
        self.blocks = nn.ModuleList([
            SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
            for _ in range(layers)
        ])
        self.tmlp = nn.Sequential(
            operations.Linear(tdim, features, device=device, dtype=dtype),
            nn.GELU(approximate="tanh"),
            operations.Linear(features, features, device=device, dtype=dtype),
        )
        self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads,
                                               device=device, dtype=dtype, operations=operations)
        self.txtmlp = nn.Sequential(
            RMSNorm(txtdim, device=device, dtype=dtype, operations=operations),
            operations.Linear(txtdim, features, device=device, dtype=dtype),
            nn.GELU(approximate="tanh"),
            operations.Linear(features, features, device=device, dtype=dtype),
        )
        self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations)
        self.tproj = nn.Sequential(
            nn.GELU(approximate="tanh"),
            operations.Linear(features, features * 6, device=device, dtype=dtype),
        )

    def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
        """Run ``_forward`` through any DIFFUSION_MODEL wrappers found in ``transformer_options``."""
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
        """Denoise a latent.

        Args:
            x: latent of shape (B, C, H, W) or (B, C, T, H, W).
            timesteps: per-sample timesteps, embedded with ``timestep_embedding``.
            context: text conditioning of shape (B, seq, txtlayers * txtdim).
            attention_mask: accepted but not applied; ``None`` is passed to the
                text fusion stack and to every block.
            transformer_options: forwarded to the text fusion stack and the blocks.

        Returns:
            A tensor with the same layout as ``x`` and ``self.channels`` channels,
            cropped back to the input's spatial size.
        """
        temporal = x.ndim == 5
        if temporal:
            b5, c5, t5, h5, w5 = x.shape
            # Fold time into batch frame-major: (B, C, T, H, W) -> (B, T, C, H, W) -> (B*T, C, H, W).
            # This mirrors the reshape(b5, t5, ...).movedim(1, 2) unfold at the end; a bare
            # reshape would interleave the channel and time axes whenever T > 1.
            x = x.movedim(2, 1).reshape(b5 * t5, c5, h5, w5)
            if t5 > 1:
                # Every frame of a sample shares that sample's timestep and text conditioning.
                timesteps = timesteps.repeat_interleave(t5, dim=0)
                context = context.repeat_interleave(t5, dim=0)
        bs, _, H_orig, W_orig = x.shape
        patch = self.patch

        # Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end.
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch))
        H, W = x.shape[-2], x.shape[-1]
        h_, w_ = H // patch, W // patch

        # context arrives as (B, seq, txtlayers*txtdim); reshape to (B, seq, txtlayers, txtdim).
        context = self._unpack_context(context)

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch)
        img = self.first(img)

        t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype))
        tvec = self.tproj(t)

        context = self.txtfusion(context, mask=None, transformer_options=transformer_options)
        context = self.txtmlp(context)

        txtlen, imglen = context.shape[1], img.shape[1]
        combined = torch.cat((context, img), dim=1)

        # Position ids: text at 0, image at (0, h_idx, w_idx).
        device = combined.device
        txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32)
        imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32)
        imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None]
        imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :]
        imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1)
        pos = torch.cat((txtpos, imgpos), dim=1)
        freqs = self.pe_embedder(pos)

        for block in self.blocks:
            combined = block(combined, tvec, freqs, None, transformer_options=transformer_options)

        final = self.last(combined, t)
        # Keep only the image tokens; the text tokens come first in the sequence.
        out = final[:, txtlen:txtlen + imglen, :]
        out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
                        h=h_, w=w_, ph=patch, pw=patch, c=self.channels)
        out = out[:, :, :H_orig, :W_orig]  # crop padding back off
        if temporal:
            out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2)
        return out

    def _unpack_context(self, context):
        """Split the fused feature axis of ``context`` into (txtlayers, txtdim).

        Raises:
            ValueError: if the last dimension is not ``txtlayers * txtdim``.
        """
        # context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim).
        b, seq, fused = context.shape
        if fused != self.txtlayers * self.txtdim:
            raise ValueError(
                f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} "
                f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. "
                f"Load the text encoder with CLIPLoader type 'krea2'."
            )
        return context.reshape(b, seq, self.txtlayers, self.txtdim)

View File

@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
)
grid_mask = None
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
additional_args.update({ "orig_patchified_shape": list(x.shape)})
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
x = x * (1 + scale) + shift
x = self.proj_out(x)
if keyframe_idxs is not None:
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
grid_mask = kwargs["grid_mask"]
orig_patchified_shape = kwargs["orig_patchified_shape"]
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)

View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from comfy.ldm.lightricks.model import Timesteps
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.model_management
import comfy.ldm.common_dit
@ -17,13 +18,11 @@ def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape).to(dtype=x.dtype)
return apply_rope1(x, freqs_cis)
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return F.silu(x) * y
return F.silu(x, inplace=True).mul_(y)
class TimestepEmbedding(nn.Module):

View File

@ -51,6 +51,18 @@ class FeedForward(nn.Module):
return hidden_states
# Adding this back because Nunchaku custom nodes rely on it, see comment here:
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
# TODO: Eventually remove this once we natively support SVDQuants
def apply_rotary_emb(x, freqs_cis):
    """Apply the 2x2 matrices in ``freqs_cis`` to consecutive pairs of the last dim of ``x``.

    ``x`` is returned untouched when its second dimension is empty. The result
    has the same shape as ``x``; no dtype cast is applied to it.
    """
    if x.shape[1] == 0:
        return x
    # View the last dim as (pairs, 1, 2) so each pair broadcasts against a 2x2 matrix.
    pairs = x.reshape(*x.shape[:-1], -1, 1, 2)
    first, second = pairs[..., 0], pairs[..., 1]
    rotated = freqs_cis[..., 0] * first + freqs_cis[..., 1] * second
    return rotated.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -8,7 +8,7 @@ from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.math import apply_rope1, rope
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
@ -570,6 +570,14 @@ class WanModel(torch.nn.Module):
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# In-context reference (Bernini)
context_latents = kwargs.get("context_latents", None)
main_len = x.shape[1]
if context_latents is not None:
for lat in context_latents:
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
x = torch.cat([x, cl], dim=1)
# context
context = self.text_embedding(context)
@ -599,6 +607,9 @@ class WanModel(torch.nn.Module):
# head
x = self.head(x, e)
if context_latents is not None:
x = x[:, :main_len]
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
@ -606,7 +617,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -638,6 +649,13 @@ class WanModel(torch.nn.Module):
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
if source_id:
d = self.dim // self.num_heads
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
return freqs
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
@ -661,6 +679,15 @@ class WanModel(torch.nn.Module):
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
# In-context reference: one rope block per stream, each with its own source_id (1, 2, ...) to distinguish from the target (id 0).
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
for i, lat in enumerate(context_latents):
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
kwargs = {**kwargs, "context_latents": context_latents}
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel):
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video)
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel):
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
if sam_latents is not None: # SCAIL-2 additive mask stream
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
@ -1695,16 +1726,44 @@ class SCAILWanModel(WanModel):
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
# reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references.
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
if ref_mask_flag is not None and not bool(ref_mask_flag):
REF_ROPE_H = 120.0
POSE_ROPE_W = 120.0
main_t_patches = t - ref_t_patches
video_t_start = max(ref_t_patches - 1, 0)
parts = []
if ref_t_patches > 0:
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
if main_t_patches > 0:
parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options))
if pose_latents is not None:
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
h_scale = h / H_pose
w_scale = w / W_pose
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf))
return torch.cat(parts, dim=1)
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
return main_freqs
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
@ -1719,12 +1778,16 @@ class SCAILWanModel(WanModel):
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
if ref_mask_latents is not None: # SCAIL-2
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
if sam_latents is not None: # SCAIL-2
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
@ -1737,5 +1800,15 @@ class SCAILWanModel(WanModel):
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
class SCAIL2WanModel(SCAILWanModel):
    """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""

    def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
        super().__init__(
            model_type=model_type,
            patch_size=patch_size,
            in_dim=in_dim,
            dim=dim,
            operations=operations,
            device=device,
            dtype=dtype,
            **kwargs,
        )
        # Patch embedding for the mask stream; created in float32 regardless of ``dtype``.
        self.patch_embedding_mask = operations.Conv3d(
            mask_in_dim,
            dim,
            kernel_size=patch_size,
            stride=patch_size,
            device=device,
            dtype=torch.float32,
        )

View File

@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
if isinstance(model, comfy.model_base.Krea2):
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = k[:-len(".weight")]
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
key_map[key_lora] = to
if isinstance(model, comfy.model_base.Lumina2):
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@ -357,6 +368,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch
import logging
import comfy.ldm.lightricks.av_model
import comfy.ldm.lightricks.symmetric_patchifier
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC
@ -54,8 +55,10 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model
import comfy.ldm.krea2.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
@ -65,6 +68,7 @@ import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.ldm.depth_anything_3.model
import comfy.model_management
import comfy.patcher_extension
@ -1202,6 +1206,127 @@ class LTXAV(BaseModel):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
    """Return ``latent_image`` as-is; ``sigma``, ``noise`` and extra kwargs are not used."""
    return latent_image
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
    """Map a primary-modality context window onto every additional modality.

    Args:
        primary_indices: window indices along ``dim`` of the primary (first) latent.
        latent_shapes: one shape per modality; ``latent_shapes[0]`` is the primary.
        dim: axis of each shape that the window runs along.

    Returns:
        A list whose first item is ``primary_indices`` itself, followed by one
        list of indices per additional modality. A modality with no elements
        along ``dim`` gets an empty list.
    """
    result = [primary_indices]
    if len(latent_shapes) < 2:
        return result

    video_total = latent_shapes[0][dim]
    for shape in latent_shapes[1:]:
        mod_total = shape[dim]
        if mod_total <= 0:
            # Nothing to select. Without this guard the ``mod_total - 1`` clamp
            # below would produce the invalid index -1.
            result.append([])
            continue

        # Map each primary index to its proportional range of modality indices and
        # concatenate in order. Preserves wrapped/strided geometry so the modality
        # attends to the same temporal regions as the primary window.
        mod_indices = []
        seen = set()
        for v_idx in primary_indices:
            a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
            a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
            # Every primary index contributes at least one modality index.
            a_end = max(a_end, a_start + 1)
            for a in range(a_start, a_end):
                if a not in seen:
                    seen.add(a)
                    mod_indices.append(a)
        result.append(mod_indices)
    return result
@staticmethod
def _get_guide_entries(conds):
for cond_list in conds:
if cond_list is None:
continue
for cond_dict in cond_list:
model_conds = cond_dict.get('model_conds', {})
entries = model_conds.get('guide_attention_entries')
if entries is not None and hasattr(entries, 'cond') and entries.cond:
return entries.cond
return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Resize one conditioning entry so it matches a context window.

    Handles the keys ``audio_denoise_mask``, ``denoise_mask``, ``keyframe_idxs``
    and ``guide_attention_entries``. Each handled case returns
    ``cond_value._copy_with(...)`` holding the window-local value; ``None`` is
    returned when no branch produces a value.

    Args:
        cond_key: name of the conditioning entry.
        cond_value: conditioning object; its ``cond`` attribute holds the payload.
        window: context window object (its ``dim``, ``index_list`` and ``guide_*``
            attributes are read here).
        x_in: latent for the current window; only its sizes are read.
        device: device the sliced tensors are moved to.
        retain_index_list: forwarded to ``window.get_tensor`` for the video mask.
    """
    # Audio denoise mask — slice using audio modality window
    if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
        audio_window = window.modality_windows.get(1)
        if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
            sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
            return cond_value._copy_with(sliced)
    # Video denoise mask — split into video + guide portions, slice each
    if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
        cond_tensor = cond_value.cond
        # Anything beyond the latent's own length along window.dim is treated as guide frames.
        guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
        if guide_count > 0:
            T_video = x_in.size(window.dim)
            video_mask = cond_tensor.narrow(window.dim, 0, T_video)
            guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
            sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
            suffix_indices = window.guide_frames_indices
            if suffix_indices:
                # Index along window.dim only, keeping every leading dimension whole.
                idx = tuple([slice(None)] * window.dim + [suffix_indices])
                sliced_guide = guide_mask[idx].to(device)
                return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
            else:
                return cond_value._copy_with(sliced_video)
    # Keyframe indices — regenerate pixel coords for window, select guide positions
    if cond_key == "keyframe_idxs":
        kf_local_pos = window.guide_kf_local_positions
        if not kf_local_pos:
            return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
        H, W = x_in.shape[3], x_in.shape[4]
        window_len = len(window.index_list)
        # account for causal_window_fix anchor in coord space size
        anchor_idx = getattr(window, 'causal_anchor_index', None)
        if anchor_idx is not None and anchor_idx >= 0:
            window_len += 1
        patchifier = self.diffusion_model.patchifier
        latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
        scale_factors = self.diffusion_model.vae_scale_factors
        pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
            latent_coords,
            scale_factors,
            causal_fix=self.diffusion_model.causal_temporal_positioning)
        # Each selected local frame position contributes a run of H * W token indices.
        tokens = []
        for pos in kf_local_pos:
            tokens.extend(range(pos * H * W, (pos + 1) * H * W))
        pixel_coords = pixel_coords[:, :, tokens, :]
        # Adjust spatial end positions for dilated (downscaled) guides.
        # Each guide entry may have a different downscale factor; expand the
        # per-entry factor to cover all tokens belonging to that entry.
        downscale_factors = window.guide_downscale_factors
        overlap_info = window.guide_overlap_info
        if downscale_factors:
            per_token_factor = []
            for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
                per_token_factor.extend([dsf] * (overlap_count * H * W))
            factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
            spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
                scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
            ).view(1, -1, 1, 1)
            pixel_coords[:, 1:, :, 1:] += spatial_end_offset
        # Broadcast the coordinates across the batch size of the original cond.
        B = cond_value.cond.shape[0]
        if B > 1:
            pixel_coords = pixel_coords.expand(B, -1, -1, -1)
        return cond_value._copy_with(pixel_coords)
    # Guide attention entries — adjust per-guide counts based on window overlap
    if cond_key == "guide_attention_entries":
        overlap_info = window.guide_overlap_info
        H, W = x_in.shape[3], x_in.shape[4]
        new_entries = []
        for entry_idx, overlap_count in overlap_info:
            e = cond_value.cond[entry_idx]
            # Copy the entry, overriding only the two window-dependent fields.
            new_entries.append({**e,
                                "pre_filter_count": overlap_count * H * W,
                                "latent_shape": [overlap_count, H, W]})
        return cond_value._copy_with(new_entries)
    return None
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@ -1518,8 +1643,26 @@ class WAN21(BaseModel):
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
# In-context reference conditioning (Bernini)
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice ``context_latents`` for a context window; defer every other key to the parent.

    For a ``context_latents`` entry whose ``cond`` is a list, each latent whose
    size along ``window.dim`` is greater than 1 and equal to that of ``x_in`` is
    sliced with ``window.get_tensor``. All other latents are only moved to
    ``device``.
    """
    # In-context cond slicing (Bernini)
    if cond_key == "context_latents":
        latents = getattr(cond_value, "cond", None)
        if isinstance(latents, list):
            axis = window.dim

            def _fit(lat):
                follows_window = lat.ndim > axis and lat.shape[axis] > 1 and lat.shape[axis] == x_in.shape[axis]
                if follows_window:
                    return window.get_tensor(lat, device, dim=axis, retain_index_list=retain_index_list)
                return lat.to(device)

            return cond_value._copy_with([_fit(lat) for lat in latents])
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@ -1728,10 +1871,14 @@ class WAN21_SCAIL(WAN21):
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
ref_latent = self.process_latent_in(reference_latents[-1])
ref_mask = torch.ones_like(ref_latent[:, :4])
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
# SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional
# references. Stack as [additional..., primary] so the primary stays adjacent to the video.
ordered = list(reference_latents[1:]) + list(reference_latents[:1])
stacked = []
for lat in ordered:
lat = self.process_latent_in(lat)
stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1))
out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2))
pose_latents = kwargs.get("pose_video_latent", None)
if pose_latents is not None:
@ -1754,6 +1901,99 @@ class WAN21_SCAIL(WAN21):
return out
class WAN21_SCAIL2(WAN21_SCAIL):
    """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""

    def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
        # super(WAN21, self) resolves __init__ past WAN21 in the MRO, so the
        # diffusion model class can be supplied here.
        super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
        self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")

        def _with_fixed_third_dim(shape):
            # Same batch/channel/spatial sizes, third entry replaced by 1.5.
            return [shape[0], shape[1], 1.5, shape[-2], shape[-1]]

        self.memory_usage_shape_process = {
            "pose_latents": _with_fixed_third_dim,
            "sam_latents": _with_fixed_third_dim,
        }
        self.image_to_video = image_to_video

    def extra_conds(self, **kwargs):
        """Extend the parent conds with the SCAIL-2 mask streams and the ref-mask flag."""
        out = super().extra_conds(**kwargs)

        # ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first,
        # then the primary ref), followed by zeros over the video frames.
        for source_key, cond_name in (("driving_mask_28ch", "sam_latents"), ("ref_mask_28ch", "ref_mask_latents")):
            mask = kwargs.get(source_key, None)
            if mask is not None:
                # Swap dims 1 and 2 before wrapping the tensor as a cond.
                out[cond_name] = comfy.conds.CONDRegular(mask.movedim(1, 2).contiguous())

        flag = kwargs.get("ref_mask_flag", None)
        if flag is not None:
            out['ref_mask_flag'] = comfy.conds.CONDConstant(flag)

        return out

    def extra_conds_shapes(self, **kwargs):
        """Report the shapes of the mask conds, with 28 as the channel count."""
        out = super().extra_conds_shapes(**kwargs)
        for source_key, cond_name in (("driving_mask_28ch", "sam_latents"), ("ref_mask_28ch", "ref_mask_latents")):
            mask = kwargs.get(source_key, None)
            if mask is not None:
                dims = mask.shape
                out[cond_name] = [dims[0], 28, dims[1], dims[3], dims[4]]
        return out

    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
        """Slice the SCAIL-2 conds for a context window; other keys go to the parent."""
        if cond_key in ("sam_latents", "pose_latents"):
            # Return sliced view omitting retain_index_list
            return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)

        is_ref_mask = cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)
        if not is_ref_mask:
            return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

        # The ref mask is N leading ref frames padded with frames of zeros, so just grab the
        # first frames for all windows
        full_ref_mask = cond_value.cond
        ref_frame_count = full_ref_mask.shape[2] - x_in.shape[2]
        if ref_frame_count < 1:
            return None

        window_length = len(window.index_list)
        # Account for the causal anchor frame if it exists
        anchor_index = getattr(window, "causal_anchor_index", None)
        if anchor_index is not None and anchor_index >= 0:
            window_length += 1

        keep = window_length + ref_frame_count
        return cond_value._copy_with(full_ref_mask[:, :, :keep].to(device))

    def concat_cond(self, **kwargs):
        """Build the 4-channel mask concatenated to the noise, or defer to the parent."""
        # The 4 extra channels are the history_mask (1 at clean-anchor frames).
        noise = kwargs.get("noise", None)
        extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
        if extra_channels != 4:
            return super().concat_cond(**kwargs)

        mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
        if mask is None:
            return torch.zeros_like(noise)[:, :4]

        device = kwargs["device"]
        if mask.shape[1] != 4:
            mask = torch.mean(mask, dim=1, keepdim=True)
        mask = 1.0 - mask
        mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

        # Zero-pad at the end of the third-from-last dim up to the noise's length.
        missing_frames = noise.shape[-3] - mask.shape[-3]
        if missing_frames > 0:
            mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, missing_frames), mode='constant', value=0)
        if mask.shape[1] == 1:
            mask = mask.repeat(1, 4, 1, 1, 1)
        return utils.resize_to_batch_size(mask, noise.shape[0])

    def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
        # Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
        return latent_image
class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
@ -1987,6 +2227,11 @@ class Omnigen2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class Boogu(Omnigen2):
    """Omnigen2 subclass whose diffusion model is ``BooguTransformer2DModel``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # super(Omnigen2, self) resolves __init__ past Omnigen2 in the MRO, so the
        # diffusion model class can be supplied here.
        super(Omnigen2, self).__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.boogu.model.BooguTransformer2DModel,
        )
        self.memory_usage_factor_conds = ("ref_latents",)
class QwenImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
@ -2034,6 +2279,17 @@ class Ideogram4(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Krea2(BaseModel):
    """Base-model wrapper that uses ``SingleStreamDiT`` as its diffusion model."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.krea2.model.SingleStreamDiT,
        )

    def extra_conds(self, **kwargs):
        """Return the parent conds, with ``c_crossattn`` set to a ``CONDRegular`` when ``cross_attn`` is given."""
        out = super().extra_conds(**kwargs)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is None:
            return out
        out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
        return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
class RT_DETR_v4(BaseModel):
    """Base-model wrapper that uses ``RTv4`` as its diffusion model."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4,
        )
class DepthAnything3(BaseModel):
    """Base-model wrapper that uses ``DepthAnything3Net`` as its diffusion model."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        net_cls = comfy.ldm.depth_anything_3.model.DepthAnything3Net
        super().__init__(model_config, model_type, device=device, unet_model=net_cls)
class ErnieImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)

View File

@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail2"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
@ -759,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if '{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage)
dit_config = {}
dit_config["image_model"] = "boogu"
dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.')
dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.')
dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
dit_config = {}
dit_config["image_model"] = "omnigen2"
@ -822,6 +834,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
return dit_config
if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2)
dit_config = {}
dit_config["image_model"] = "krea2"
head_dim = 128
first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2)
dit_config["features"] = first_w.shape[0]
dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2
dit_config["patch"] = 2
dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim
dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1]
dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0]
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
@ -860,6 +887,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "DepthAnything3"
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
embed_dim = patch_w.shape[0]
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
if backbone_name is None:
return None
dit_config["backbone_name"] = backbone_name
# Detect DA3 extensions on top of vanilla DINOv2.
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
qknorm_indices = [
i for i in range(depth)
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
]
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
# cat_token=True is implied by the presence of camera_token.
if has_camera_token:
dit_config["alt_start"] = qknorm_start
dit_config["rope_start"] = qknorm_start
dit_config["qknorm_start"] = qknorm_start
dit_config["cat_token"] = True
else:
dit_config["alt_start"] = -1
dit_config["rope_start"] = -1
dit_config["qknorm_start"] = -1
dit_config["cat_token"] = False
# Detect head type and config.
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
dit_config["head_out_channels"] = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4)
]
if has_aux:
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
dit_config["head_type"] = "dualdpt"
dit_config["head_output_dim"] = 2
dit_config["head_use_sky_head"] = False
else:
dit_config["head_type"] = "dpt"
dit_config["head_output_dim"] = state_dict[
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
].shape[0]
dit_config["head_use_sky_head"] = (
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
)
# out_layers: hard-coded per upstream YAML config (depth-aware default).
if depth >= 24:
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
if has_aux:
dit_config["out_layers"] = [11, 15, 19, 23]
else:
dit_config["out_layers"] = [4, 11, 17, 23]
else:
# vits/vitb: 12 blocks
dit_config["out_layers"] = [5, 7, 9, 11]
# Camera encoder/decoder presence (multi-view + pose path).
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
dit_config["has_cam_enc"] = has_cam_enc
dit_config["has_cam_dec"] = has_cam_dec
if has_cam_enc:
cam_enc_w = state_dict.get(
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
)
if cam_enc_w is not None:
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
if has_cam_dec:
cam_dec_w = state_dict.get(
'{}cam_dec.fc_t.weight'.format(key_prefix)
)
if cam_dec_w is not None:
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
dit_config = {}
dit_config["image_model"] = "ernie"

View File

@ -534,8 +534,10 @@ try:
except:
pass
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
torch.backends.cudnn.benchmark = True
def set_cudnn_benchmark():
    """Set ``torch.backends.cudnn.benchmark`` from the ``--fast`` feature list.

    The flag becomes True when ``PerformanceFeature.AutoTune`` is in
    ``args.fast`` and False otherwise. Nothing is touched unless both CUDA
    and cuDNN report themselves as available.
    """
    if not torch.cuda.is_available():
        return
    if not torch.backends.cudnn.is_available():
        return
    autotune_requested = PerformanceFeature.AutoTune in args.fast
    torch.backends.cudnn.benchmark = autotune_requested
try:
if torch_version_numeric >= (2, 5):
@ -641,6 +643,8 @@ def free_pins(size, evict_active=False):
return freed_total
def ensure_pin_budget(size, evict_active=False):
if args.high_ram:
return True
if args.fast_disk:
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
else:
@ -651,8 +655,7 @@ def ensure_pin_budget(size, evict_active=False):
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=True):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
def free_registrations(shortfall, evict_active=True):
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
@ -674,6 +677,9 @@ def ensure_pin_registerable(size, evict_active=True):
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
def ensure_pin_registerable(size, evict_active=True):
    """Ask free_registrations() to make room for registering ``size`` more bytes.

    The shortfall is how far the current pinned total plus ``size`` exceeds
    MAX_PINNED_MEMORY; the result of free_registrations() is returned as-is.
    """
    shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
    return free_registrations(shortfall, evict_active=evict_active)
class LoadedModel:
def __init__(self, model: ModelPatcher):
self._set_model(model)
@ -956,8 +962,6 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():
@ -1494,6 +1498,8 @@ if not args.disable_pinned_memory:
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
def pinned_hostbuf_size(size):
    """Return twice ``size`` as a non-negative int.

    Unless ``args.high_ram`` is set, ``size`` is first capped at
    MAX_PINNED_MEMORY.
    """
    effective = size if args.high_ram else min(size, MAX_PINNED_MEMORY)
    return max(0, int(effective * 2))
def discard_cuda_async_error():

View File

@ -379,10 +379,11 @@ class ModelPatcher:
def get_clone_model_override(self):
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
class_ = self.__class__
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
if self.is_dynamic() and disable_dynamic or force_deepcopy:
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
if model_override is None:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")

View File

@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if pin is not None:
cast_maybe_lowvram_patch([pin], dest, offload_stream)
return
if signature is None:
if signature is None or args.high_ram:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
@ -299,21 +299,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
if hasattr(s, "_v") and comfy.model_management.is_device_cpu(device):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
elif hasattr(s, "_v") and s.weight.device != device:
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None

View File

@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None):
not comfy.model_management.ensure_pin_registerable(registerable_size)):
return _steal_pin(module, stack, buckets, size, priority)
extended = False
try:
hostbuf.extend(size=size)
hostbuf.extend(size=size, register=False)
extended = True
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
pin.untyped_storage()._comfy_hostbuf = hostbuf
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
comfy.model_management.free_registrations(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
del pin
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
except RuntimeError:
if extended:
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
module._pin = pin
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1

View File

@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.krea2
import comfy.text_encoders.ideogram4
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
@ -67,6 +68,8 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.qwen3vl
import comfy.text_encoders.boogu
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
@ -1302,6 +1305,8 @@ class CLIPType(Enum):
LENS = 28
PIXELDIT = 29
IDEOGRAM4 = 30
BOOGU = 31
KREA2 = 32
@ -1355,6 +1360,8 @@ class TEModel(Enum):
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
QWEN3VL_4B = 34
QWEN3VL_8B = 35
def detect_te_model(sd):
@ -1416,6 +1423,8 @@ def detect_te_model(sd):
if weight.shape[0] == 5120:
return TEModel.QWEN35_27B
return TEModel.QWEN35_2B
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
if "model.layers.0.post_attention_layernorm.weight" in sd:
weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1614,6 +1623,28 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
elif te_model in (TEModel.QWEN3VL_4B, TEModel.QWEN3VL_8B):
if clip_type == CLIPType.IDEOGRAM4 and te_model == TEModel.QWEN3VL_8B: # Ideogram4 reuses the full Qwen3-VL-8B (13-layer tap for conditioning + multimodal generate).
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer
elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template.
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate).
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer
else:
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
elif te_model == TEModel.QWEN3_06B:
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer

View File

@ -25,6 +25,8 @@ import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4
import comfy.text_encoders.boogu
import comfy.text_encoders.krea2
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
@ -1450,6 +1452,17 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class WAN21_SCAIL2(WAN21_T2V):
    """WAN21_T2V subclass whose unet_config carries model_type "scail2"."""

    unet_config = {
        "image_model": "wan2.1",
        "model_type": "scail2",
    }

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.WAN21_SCAIL2 wrapper (``state_dict``/``prefix`` are not used here)."""
        return model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
class WAN22_WanDancer(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1747,6 +1760,27 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class Boogu(Omnigen2):
    """Omnigen2-derived model config selected by ``image_model == "boogu"``."""

    unet_config = {"image_model": "boogu"}
    sampling_settings = {"multiplier": 1.0, "shift": 3.16}
    memory_usage_factor = 2.15

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.Boogu wrapper (``state_dict``/``prefix`` are not used here)."""
        return model_base.Boogu(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the ClipTarget pairing BooguTokenizer with the Boogu text encoder."""
        # Detection looks under the first text-encoder prefix for the qwen3vl_8b weights.
        te_prefix = "{}qwen3vl_8b.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
        te_class = comfy.text_encoders.boogu.te(**detected)
        return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, te_class)
class Ideogram4(supported_models_base.BASE):
unet_config = {
"image_model": "ideogram4",
@ -1785,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
class Krea2(supported_models_base.BASE):
    """Model config selected by ``image_model == "krea2"``."""

    unet_config = {"image_model": "krea2"}
    sampling_settings = {"multiplier": 1.0, "shift": 1.15}

    memory_usage_factor = 2.2
    latent_format = latent_formats.Wan21
    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.Krea2 wrapper (``state_dict``/``prefix`` are not used here)."""
        return model_base.Krea2(self, device=device)

    def clip_target(self, state_dict={}):
        """Return the ClipTarget pairing Krea2Tokenizer with the Krea2 text encoder."""
        # Detection looks under the first text-encoder prefix for the qwen3vl_4b weights.
        te_prefix = "{}qwen3vl_4b.transformer.".format(self.text_encoder_key_prefix[0])
        detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
        te_class = comfy.text_encoders.krea2.te(**detected)
        return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, te_class)
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
@ -2046,6 +2109,23 @@ class RT_DETR_v4(supported_models_base.BASE):
return None
class DepthAnything3(supported_models_base.BASE):
    """Model config selected by ``image_model == "DepthAnything3"``; it has no text encoder."""

    unet_config = {"image_model": "DepthAnything3"}

    # Mono path: no num_heads / num_head_channels needed.
    unet_extra_config = {}

    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    def get_model(self, state_dict, prefix="", device=None):
        """Build the model_base.DepthAnything3 wrapper (``state_dict``/``prefix`` are not used here)."""
        model = model_base.DepthAnything3(self, device=device)
        return model

    def clip_target(self, state_dict={}):
        """Always None: this config declares no CLIP/text-encoder target."""
        return None
class ErnieImage(supported_models_base.BASE):
unet_config = {
"image_model": "ernie",
@ -2260,6 +2340,7 @@ models = [
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
WAN21_SCAIL2,
WAN22_WanDancer,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
@ -2272,8 +2353,10 @@ models = [
ACEStep,
ACEStep15,
Omnigen2,
Boogu,
QwenImage,
Ideogram4,
Krea2,
Flux2,
Lens,
Kandinsky5Image,
@ -2287,4 +2370,5 @@ models = [
CogVideoX_I2V,
CogVideoX_T2V,
SVD_img2vid,
DepthAnything3,
]

View File

@ -0,0 +1,58 @@
"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim).
Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature
(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer).
The model itself is the standard Qwen3-VL TE, only the chat template differs
(a fixed system prompt and no <think> block).
"""
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# System prompts from the reference pipeline (pipeline_boogu.py).
# T2I (non-empty instruction, no image) uses the helpful-assistant prompt;
# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt.
BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows."
BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
    """Qwen3VLTokenizer (model_type "qwen3vl_8b") carrying Boogu's three chat templates."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            model_type="qwen3vl_8b",
        )
        # All templates stop after the user turn (apply_chat_template without add_generation_prompt).
        system_open = "<|im_start|>system\n"
        text_turn = "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
        image_turn = "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"

        self.llama_template = system_open + BOOGU_T2I_SYSTEM + text_turn
        self.llama_template_images = system_open + BOOGU_DROP_SYSTEM + image_turn
        # Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction.
        self.llama_template_drop = system_open + BOOGU_DROP_SYSTEM + text_turn

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
        """Tokenize ``text``, switching to the "drop" template for an empty, image-less prompt.

        The switch only happens when the caller supplied no ``llama_template``.
        """
        wants_drop = llama_template is None and len(images) == 0 and text.strip() == ""
        template = self.llama_template_drop if wants_drop else llama_template
        # Boogu conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds by default.
        return super().tokenize_with_weights(
            text,
            return_word_ids=return_word_ids,
            llama_template=template,
            images=images,
            prevent_empty_text=prevent_empty_text,
            thinking=thinking,
            **kwargs,
        )
class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
    """Qwen3VLClipModel with ``layer_norm_hidden_state`` switched on after construction."""

    def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
        super().__init__(
            device=device,
            dtype=dtype,
            attention_mask=attention_mask,
            model_options=model_options,
            model_type=model_type,
        )
        # apply the final RMSNorm to the tapped last layer
        self.layer_norm_hidden_state = True
class BooguTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel registered under the name "qwen3vl_8b" and backed by BooguQwen3VLClipModel."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        def clip_model(**kw):
            # Pin model_type so the base class can construct the clip model with its own kwargs only.
            return BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b")

        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=clip_model,
            model_options=model_options,
        )
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a BooguTEModel subclass with the given dtype/quantization overrides baked in.

    ``dtype_llama``, when not None, replaces the ``dtype`` passed at
    construction. ``llama_quantization_metadata``, when not None, is stored
    under ``model_options["quantization_metadata"]`` on a copy of the options.
    """
    class BooguTEModel_(BooguTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            chosen_dtype = dtype if dtype_llama is None else dtype_llama
            options = model_options
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict is left untouched.
                options = model_options.copy()
                options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return BooguTEModel_

View File

@ -9,6 +9,7 @@ import os
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
@ -32,7 +33,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
if text.startswith('<|im_start|>'):
llama_text = text
elif llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
@ -75,3 +78,43 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Ideogram4TEModel_
# Full Qwen3-VL-8B variant with vision
class Ideogram4Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
    """Qwen3VLClipModel (model_type "qwen3vl_8b") that taps IDEOGRAM4_TAP_LAYERS."""

    def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=IDEOGRAM4_TAP_LAYERS,
            layer_idx=None,
            dtype=dtype,
            attention_mask=attention_mask,
            model_options=model_options,
            model_type="qwen3vl_8b",
        )
class Ideogram4Qwen3VLTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel named "qwen3vl_8b" that flattens the tapped layers into the feature axis."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="qwen3vl_8b",
            clip_model=Ideogram4Qwen3VLClipModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode, then fold the tap axis into the hidden axis to get a 3D tensor."""
        stacked, pooled, extra = super().encode_token_weights(token_weight_pairs)
        # (B, n_taps=13, seq, 4096), ascending layer order.
        batch, taps, seq_len, hidden = stacked.shape
        # Move taps last, then merge (hidden, taps): (B, seq, 4096*13 = 53248).
        tap_last = stacked.permute(0, 2, 3, 1)
        flat = tap_last.reshape(batch, seq_len, hidden * taps)
        return flat, pooled, extra
class Ideogram4Qwen3VLTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
    """Qwen3VLTokenizer fixed to model_type "qwen3vl_8b" with ``thinking`` defaulting to True."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            model_type="qwen3vl_8b",
        )

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
        """Forward every argument unchanged to the parent tokenizer."""
        # Ideogram 4 conditions on the no-think template; default thinking=True drops the empty think block qwen3vl adds.
        return super().tokenize_with_weights(
            text,
            return_word_ids=return_word_ids,
            llama_template=llama_template,
            images=images,
            prevent_empty_text=prevent_empty_text,
            thinking=thinking,
            **kwargs,
        )
def te_qwen3vl(dtype_llama=None, llama_quantization_metadata=None):
    """Return an Ideogram4Qwen3VLTEModel subclass with dtype/quantization overrides baked in.

    ``dtype_llama``, when not None, replaces the ``dtype`` passed at
    construction. ``llama_quantization_metadata``, when not None, is stored
    under ``model_options["quantization_metadata"]`` on a copy of the options.
    """
    class Ideogram4Qwen3VLTEModel_(Ideogram4Qwen3VLTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            chosen_dtype = dtype if dtype_llama is None else dtype_llama
            options = model_options
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict is left untouched.
                options = model_options.copy()
                options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return Ideogram4Qwen3VLTEModel_

View File

@ -0,0 +1,84 @@
"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap.
K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B
(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and
consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor,
so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model.
"""
import numbers
import torch
import comfy.text_encoders.qwen3vl
from comfy import sd1_clip
# tap k == hidden_states[k] (no offset).
KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35]
# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix.
KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
    """Qwen3VLTokenizer fixed to model_type "qwen3vl_4b" using KREA2_TEMPLATE for text prompts."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=tokenizer_data,
            model_type="qwen3vl_4b",
        )
        # conditioning template; image text-gen uses qwen3vl's default image template.
        self.llama_template = KREA2_TEMPLATE

    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
        """Forward every argument unchanged to the parent tokenizer."""
        # Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
        return super().tokenize_with_weights(
            text,
            return_word_ids=return_word_ids,
            llama_template=llama_template,
            images=images,
            prevent_empty_text=prevent_empty_text,
            thinking=thinking,
            **kwargs,
        )
class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
    """Qwen3VLClipModel (model_type "qwen3vl_4b") that taps KREA2_TAP_LAYERS."""

    def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
        super().__init__(
            device=device,
            layer=KREA2_TAP_LAYERS,
            layer_idx=None,
            dtype=dtype,
            attention_mask=attention_mask,
            model_options=model_options,
            model_type="qwen3vl_4b",
        )
class Krea2TEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel named "qwen3vl_4b" that strips the template prefix and flattens the tapped layers."""

    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options)

    def encode_token_weights(self, token_weight_pairs, template_end=-1):
        """Encode tokens, drop the leading template tokens, and flatten the layer axis.

        Args:
            token_weight_pairs: mapping whose "qwen3vl_4b" entry holds the
                (token, weight, ...) sequences; only the first sequence is
                inspected when locating the prefix.
            template_end: index of the first token to keep. The default -1
                means "auto-detect": the position of the second token with id
                151644 (or the only one, if there is a single occurrence) is
                used, and a directly following "user" + "\\n" pair is skipped
                as well.

        Returns:
            (out, pooled, extra) where ``out`` is (B, seq, n_layers * hidden).
            ``extra["attention_mask"]`` is sliced the same way and removed
            entirely when nothing is masked.
        """
        out, pooled, extra = super().encode_token_weights(token_weight_pairs)  # out: (B, 12, seq, 2560)
        tok_pairs = token_weight_pairs["qwen3vl_4b"][0]

        # Strip the system + user-opening prefix
        prefix_located = True
        if template_end == -1:
            count_im_start = 0
            for i, v in enumerate(tok_pairs):
                elem = v[0]
                if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
                    # 151644 is presumably the <|im_start|> id (per the counter's name) - TODO confirm.
                    if elem == 151644 and count_im_start < 2:
                        template_end = i
                        count_im_start += 1
            if template_end == -1:
                # No marker found (e.g. a custom template): keep the whole sequence.
                # Leaving -1 here would make the slices below keep only the final token.
                template_end = 0
                prefix_located = False

        if prefix_located and out.shape[2] > (template_end + 3):
            if tok_pairs[template_end + 1][0] == 872:  # "user"
                if tok_pairs[template_end + 2][0] == 198:  # "\n"
                    template_end += 3

        out = out[:, :, template_end:]
        b, n, seq, h = out.shape
        # Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model.
        out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h)

        if "attention_mask" in extra:
            extra["attention_mask"] = extra["attention_mask"][:, template_end:]
            # An all-ones mask carries no information, so drop it.
            if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
                extra.pop("attention_mask")
        return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a Krea2TEModel subclass with the given dtype/quantization overrides baked in.

    ``llama_quantization_metadata``, when not None, is stored under
    ``model_options["quantization_metadata"]`` on a copy of the options.
    ``dtype_llama``, when not None, replaces the ``dtype`` passed at
    construction.
    """
    class Krea2TEModel_(Krea2TEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            options = model_options
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict is left untouched.
                options = model_options.copy()
                options["quantization_metadata"] = llama_quantization_metadata
            chosen_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=chosen_dtype, model_options=options)

    return Krea2TEModel_

View File

@ -251,6 +251,19 @@ class Qwen3_8BConfig:
lm_head: bool = True
stop_tokens = [151643, 151645]
# Qwen3-VL 8B config: inherits Qwen3_8BConfig and overrides the values below.
@dataclass
class Qwen3VL_8BConfig(Qwen3_8BConfig):
max_position_embeddings: int = 262144
rope_theta: float = 5000000.0
# The next two names carry no annotation, so @dataclass leaves them as plain
# class attributes rather than fields (they are not __init__ parameters).
rope_dims = [24, 20, 20]
# Read via getattr(self.config, "interleaved_mrope", False) in Llama2_.
interleaved_mrope = True
# Qwen3-VL 4B config: same as Qwen3VL_8BConfig except for the overrides below.
@dataclass
class Qwen3VL_4BConfig(Qwen3VL_8BConfig):
hidden_size: int = 2560
intermediate_size: int = 9728
lm_head: bool = False # 4B ties word embeddings
@dataclass
class Ovis25_2BConfig:
vocab_size: int = 151936
@ -703,7 +716,8 @@ class Llama2_(nn.Module):
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True,
dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None,deepstack_embeds=None, visual_pos_masks=None):
if embeds is not None:
x = embeds
else:
@ -767,6 +781,10 @@ class Llama2_(nn.Module):
if current_kv is not None:
next_key_values.append(current_kv)
# DeepStack: add per-layer visual features into the first len() decoder layers at image positions (Qwen3-VL)
if deepstack_embeds is not None and i < len(deepstack_embeds):
x[visual_pos_masks] = x[visual_pos_masks] + deepstack_embeds[i].to(x)
if i == intermediate_output:
intermediate = x.clone()
@ -860,7 +878,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, position_ids=None, deepstack_embeds=None, visual_pos_masks=None):
device = embeds.device
if stop_tokens is None:
@ -884,10 +902,18 @@ class BaseGenerate:
generated_token_ids = []
pbar = comfy.utils.ProgressBar(max_length)
# MRoPE: prefill uses explicit 3D position_ids, decode continues from the last position
next_pos = int(position_ids[:, -1].max()) + 1 if position_ids is not None else None
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
# DeepStack visual features are injected on the prefill only; gemma4's forward lacks these kwargs.
extra = {}
if step == 0 and deepstack_embeds is not None:
extra["deepstack_embeds"] = deepstack_embeds
extra["visual_pos_masks"] = visual_pos_masks
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, position_ids=position_ids, **extra)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
@ -895,6 +921,9 @@ class BaseGenerate:
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
if next_pos is not None: # advance MRoPE position for the next (decode) step
position_ids = torch.tensor([[next_pos]], device=device)
next_pos += 1
pbar.update(1)
if token_id in stop_tokens:

View File

@ -3,7 +3,6 @@ import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
import os
import math
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
@ -563,6 +562,8 @@ class Qwen35VisionModel(nn.Module):
for _ in range(config["depth"])
])
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
self.deepstack_visual_indexes = [] # DeepStack, per-layer visual features (Qwen3-VL)
self.deepstack_merger_list = None
def rot_pos_emb(self, grid_thw):
merge_size = self.spatial_merge_size
@ -664,9 +665,14 @@ class Qwen35VisionModel(nn.Module):
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
for blk in self.blocks:
deepstack_features = []
for layer_num, blk in enumerate(self.blocks):
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
if self.deepstack_merger_list is not None and layer_num in self.deepstack_visual_indexes:
deepstack_features.append(self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](x))
merged = self.merger(x)
if self.deepstack_merger_list is not None:
return merged, deepstack_features
return merged
# Model Wrapper
@ -690,30 +696,7 @@ class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
return None, None
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
grid = None
position_ids = None
offset = 0
for e in embeds_info:
if e.get("type") == "image":
grid = e.get("extra", None)
start = e.get("index")
if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start
len_max = int(grid.max()) // 2
start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
max_d = int(grid[0][2]) // 2
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
offset += len_max - (end - start)
if grid is None:
position_ids = None
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, embeds.shape[1], embeds.device)
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):

View File

@ -0,0 +1,193 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Qwen2Tokenizer
from comfy import sd1_clip
import comfy.text_encoders.qwen_vl
from .qwen35 import Qwen35VisionModel
from .llama import BaseLlama, BaseQwen3, BaseGenerate, Llama2_, Qwen3VL_4BConfig, Qwen3VL_8BConfig
# Per-variant vision-tower hyperparameters, keyed by model_type. The
# "deepstack_visual_indexes" entry lists the vision block indexes after which a
# DeepStack merger is applied (consumed by Qwen3VLVisionModel).
QWEN3VL_VISION = {
"qwen3vl_4b": dict(hidden_size=1024, intermediate_size=4096, depth=24, deepstack_visual_indexes=[5, 11, 17]),
"qwen3vl_8b": dict(hidden_size=1152, intermediate_size=4304, depth=27, deepstack_visual_indexes=[8, 16, 24]),
}
# Vision hyperparameters shared by every variant; Qwen3VL.__init__ merges these with
# the per-variant entry above and adds "out_hidden_size".
QWEN3VL_VISION_COMMON = dict(num_heads=16, patch_size=16, temporal_patch_size=2, in_channels=3,
spatial_merge_size=2, num_position_embeddings=2304)
# model_type -> text-decoder config dataclass.
QWEN3VL_CONFIGS = {"qwen3vl_4b": Qwen3VL_4BConfig, "qwen3vl_8b": Qwen3VL_8BConfig}
class Qwen3VLDeepstackMerger(nn.Module):
    """DeepStack patch merger.

    Flattens the input to rows of ``hidden_size * spatial_merge_size ** 2`` features,
    applies LayerNorm on that merged width (a "postshuffle" norm, i.e. after the
    spatial merge, unlike the main merger) and projects to ``out_hidden_size`` with a
    two-layer GELU MLP.

    ``ops`` must provide ``LayerNorm`` and ``Linear`` factories.
    """

    def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
        super().__init__()
        factory_kwargs = dict(device=device, dtype=dtype)
        self.merge_dim = hidden_size * (spatial_merge_size ** 2)
        # Submodules are created in the same order as before (norm, fc1, fc2).
        self.norm = ops.LayerNorm(self.merge_dim, eps=1e-6, **factory_kwargs)
        self.linear_fc1 = ops.Linear(self.merge_dim, self.merge_dim, **factory_kwargs)
        self.linear_fc2 = ops.Linear(self.merge_dim, out_hidden_size, **factory_kwargs)

    def forward(self, x):
        """Return the merged features, one row per group of ``merge_dim`` input values."""
        merged = x.view(-1, self.merge_dim)
        hidden = self.linear_fc1(self.norm(merged))
        return self.linear_fc2(F.gelu(hidden))
class Qwen3VLVisionModel(Qwen35VisionModel):
    """Qwen3.5 vision tower with DeepStack enabled.

    Overrides the two DeepStack attributes of the base model: the list of vision
    block indexes to tap, and one Qwen3VLDeepstackMerger per tapped index.
    """

    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__(config, device=device, dtype=dtype, ops=ops)
        tapped_layers = config["deepstack_visual_indexes"]
        self.deepstack_visual_indexes = tapped_layers
        mergers = []
        for _ in tapped_layers:
            mergers.append(
                Qwen3VLDeepstackMerger(
                    self.hidden_size,
                    self.spatial_merge_size,
                    config["out_hidden_size"],
                    device=device,
                    dtype=dtype,
                    ops=ops,
                )
            )
        self.deepstack_merger_list = nn.ModuleList(mergers)
class Qwen3VL(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
    """Qwen3-VL wrapper: a Llama2_ text decoder plus a DeepStack-enabled vision tower.

    ``model_type`` selects the entries of QWEN3VL_CONFIGS / QWEN3VL_VISION that are used.
    """

    model_type = "qwen3vl_8b"

    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        variant = self.model_type
        config = QWEN3VL_CONFIGS[variant](**config_dict)
        self.num_layers = config.num_hidden_layers
        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)

        # Shared vision settings, then per-variant overrides, then the decoder width.
        vision_config = dict(QWEN3VL_VISION_COMMON)
        vision_config.update(QWEN3VL_VISION[variant])
        vision_config["out_hidden_size"] = config.hidden_size
        self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

    def preprocess_embed(self, embed, device):
        """Encode an image embed; returns (merged features, extra dict) or (None, None)."""
        if embed["type"] != "image":
            return None, None
        # Qwen3-VL normalizes to [-1, 1] (mean/std 0.5), unlike Qwen2.5-VL's CLIP normalization.
        pixels, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
        merged, deepstack = self.visual(pixels.to(device, dtype=torch.float32), grid)
        extra = {"grid": grid, "deepstack": deepstack}
        return merged, extra

    def build_image_inputs(self, embeds, embeds_info):
        """Return (position_ids, visual_pos_masks, deepstack) for the prompt.

        All three are None when ``embeds_info`` holds no image entries.
        """
        image_entries = [e for e in embeds_info if e.get("type") == "image"]
        image_entries.sort(key=lambda e: e["index"])
        if not image_entries:
            return None, None, None

        device = embeds.device
        seq = embeds.shape[1]
        position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, seq, device)

        # DeepStack: mask of image positions + per-vision-layer features to inject there.
        visual_pos_masks = torch.zeros((1, seq), dtype=torch.bool, device=device)
        deepstack = None
        for entry in image_entries:
            first = entry["index"]
            last = entry["size"] + first
            visual_pos_masks[0, first:last] = True
            features = entry["extra"]["deepstack"]
            if deepstack is None:
                deepstack = list(features)
            else:
                # Append this image's features to each per-layer tensor, in index order.
                deepstack = [torch.cat([deepstack[i], feat], dim=0) for i, feat in enumerate(features)]
        return position_ids, visual_pos_masks, deepstack
def _make_qwen3vl_model(model_type):
# Returns a fresh Qwen3VL subclass whose class attribute model_type is set to the
# given value. A new subclass is created on every call, so the shared Qwen3VL
# default ("qwen3vl_8b") is never mutated.
class Qwen3VL_(Qwen3VL):
pass
Qwen3VL_.model_type = model_type
return Qwen3VL_
class Qwen3VLClipModel(sd1_clip.SDClipModel):
    """SDClipModel wrapper around a Qwen3-VL transformer of the requested ``model_type``."""

    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
        transformer_class = _make_qwen3vl_model(model_type)
        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"pad": 151643},
            layer_norm_hidden_state=False,
            model_class=transformer_class,
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )

    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
        """Generate from tokenized input, forwarding image inputs built by the transformer."""
        if isinstance(tokens, dict):
            # Only the first entry of a tokenizer-keyed dict is used.
            tokens = next(iter(tokens.values()))
        # Keep only the first element of each token tuple.
        token_ids = [[entry[0] for entry in batch] for batch in tokens]
        embeds, _, _, embeds_info = self.process_tokens(token_ids, self.execution_device)
        position_ids, visual_pos_masks, deepstack = self.transformer.build_image_inputs(embeds, embeds_info)
        return self.transformer.generate(
            embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed,
            presence_penalty=presence_penalty,
            position_ids=position_ids,
            visual_pos_masks=visual_pos_masks,
            deepstack_embeds=deepstack,
        )
class Qwen3VLTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel whose inner clip model is a Qwen3VLClipModel of the given ``model_type``."""

    def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"):
        def clip_model(**kw):
            # Forward the caller's keyword arguments and pin the variant.
            return Qwen3VLClipModel(**kw, model_type=model_type)

        super().__init__(
            device=device,
            dtype=dtype,
            name=model_type,
            clip_model=clip_model,
            model_options=model_options,
        )
class Qwen3VLSDTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer using Qwen2Tokenizer with the tokenizer files in ``qwen25_tokenizer`` next to this module."""

    def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"):
        module_dir = os.path.dirname(os.path.realpath(__file__))
        tokenizer_path = os.path.join(module_dir, "qwen25_tokenizer")
        super().__init__(
            tokenizer_path,
            pad_with_end=False,
            embedding_directory=embedding_directory,
            embedding_size=embedding_size,
            embedding_key=embedding_key,
            tokenizer_class=Qwen2Tokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_token=151643,
            tokenizer_data=tokenizer_data,
        )
class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
# Chat-template tokenizer for Qwen3-VL: wraps the prompt in a user/assistant template
# and swaps <|image_pad|> tokens for image embed placeholders.
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"):
# Embedding width follows the variant: 2560 for "qwen3vl_4b", 4096 otherwise.
embedding_size = 2560 if model_type == "qwen3vl_4b" else 4096
tokenizer = lambda *a, **kw: Qwen3VLSDTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
# Default templates: text-only, and one with a single vision block before the text.
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
# An `image` keyword is used only when `images` is empty; it is split along its
# first dimension into one slice per image.
image = kwargs.get("image", None)
if image is not None and len(images) == 0:
images = [image[i:i + 1] for i in range(image.shape[0])]
# Text that already starts with <|im_start|> is used as-is (no template applied).
skip_template = text.startswith('<|im_start|>')
if prevent_empty_text and text == '':
text = ' '
if skip_template:
llama_text = text
else:
# Template priority: explicit llama_template, then text-only, then image template.
if llama_template is not None:
template = llama_template
elif len(images) == 0:
template = self.llama_template
else:
template = self.llama_template_images
if len(images) > 1:
# Repeat the first vision block once per image.
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
template = template.replace(vision_block, vision_block * len(images), 1)
llama_text = template.format(text)
if not thinking: # Qwen3 convention: empty think block suppresses reasoning
llama_text += "<think>\n\n</think>\n\n"
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
key_name = next(iter(tokens))
# Replace <|image_pad|> tokens, in order, with image embed dicts while images remain.
embed_count = 0
for r in tokens[key_name]:
for i in range(len(r)):
if r[i][0] == 151655: # <|image_pad|>
if len(images) > embed_count:
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return tokens
def tokenizer(model_type="qwen3vl_8b"):
# Class factory: returns a Qwen3VLTokenizer subclass whose constructor takes only
# (embedding_directory, tokenizer_data) and always passes the captured model_type.
class Qwen3VLTokenizer_(Qwen3VLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
return Qwen3VLTokenizer_
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"):
    """Return a Qwen3VLTEModel subclass bound to ``model_type``.

    When ``dtype_llama`` is given it replaces the constructor's ``dtype``. When
    ``llama_quantization_metadata`` is given it is stored under
    ``"quantization_metadata"`` in a copy of the constructor's ``model_options``.
    """

    class Qwen3VLTEModel_(Qwen3VLTEModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            options = model_options
            if llama_quantization_metadata is not None:
                # Copy first so the caller's dict (or the shared default) is not modified.
                options = model_options.copy()
                options["quantization_metadata"] = llama_quantization_metadata
            super().__init__(device=device, dtype=effective_dtype, model_options=options, model_type=model_type)

    return Qwen3VLTEModel_

View File

@ -88,6 +88,32 @@ def process_qwen2vl_images(
return flatten_patches, image_grid_thw
def qwen2vl_mrope_position_ids(embeds_info, seq_len, device):
    """Build (3, seq_len) T/H/W MRoPE position ids for a prompt with image spans.

    Text positions run sequentially; each image span gets positions derived from
    its grid. Entries whose "type" is not "image" are ignored.

    Args:
        embeds_info: iterable of dicts with "type", "index", "size" and "extra".
            "extra" is the image grid_thw, or a dict carrying it under "grid".
        seq_len: total sequence length.
        device: device for the returned tensor.

    Returns:
        A (3, seq_len) tensor, or None when there are no image embeds.
    """
    pos = None
    shift = 0  # running correction applied to positions after earlier image spans
    for entry in embeds_info:
        if entry.get("type") != "image":
            continue
        extra = entry.get("extra", None)
        grid = extra["grid"] if isinstance(extra, dict) else extra
        first = entry.get("index")
        if pos is None:
            # First image: allocate and number the leading text sequentially.
            pos = torch.zeros((3, seq_len), device=device)
            pos[:, :first] = torch.arange(0, first, device=device)
        last = entry.get("size") + first
        span = last - first

        # Everything after the span continues from the largest halved grid extent.
        extent = int(grid.max()) // 2
        tail_start = extent + first + shift
        pos[:, last:] = torch.arange(tail_start, tail_start + (seq_len - last), device=device)

        base = first + shift
        pos[0, first:last] = base

        # Second row: each value is repeated consecutively across the span.
        rows = int(grid[0][1]) // 2
        row_ids = torch.arange(base, base + rows, device=device)
        pos[1, first:last] = row_ids.repeat_interleave(math.ceil(span / rows))[:span]

        # Third row: the value range is tiled across the span.
        cols = int(grid[0][2]) // 2
        col_ids = torch.arange(base, base + cols, device=device)
        pos[2, first:last] = col_ids.repeat(math.ceil(span / cols))[:span]

        shift += extent - span
    return pos
class VisionPatchEmbed(nn.Module):
def __init__(
self,

View File

@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def krea2_to_diffusers(mmdit_config, output_prefix=""):
    """Map diffusers-style Krea2 weight names to native checkpoint weight names.

    Args:
        mmdit_config: model config; ``"layers"`` gives the number of transformer
            blocks (0 when missing).
        output_prefix: string prepended to every native key.

    Returns:
        dict of ``"<diffusers name>.weight"`` -> ``"<output_prefix><native name>.weight"``.
    """
    # Per-block submodule renames, applied to every transformer / text-fusion block.
    block_suffixes = (
        ("attn.to_q", "attn.wq"),
        ("attn.to_k", "attn.wk"),
        ("attn.to_v", "attn.wv"),
        ("attn.to_gate", "attn.gate"),
        ("attn.to_out.0", "attn.wo"),
        ("attn.to_out", "attn.wo"),  # some tools drop the ".0" on to_out
        ("ff.gate", "mlp.gate"),
        ("ff.up", "mlp.up"),
        ("ff.down", "mlp.down"),
    )
    # Renames of modules that live outside the repeated blocks.
    top_level = (
        ("img_in", "first"),
        ("time_embed.linear_1", "tmlp.0"),
        ("time_embed.linear_2", "tmlp.2"),
        ("time_mod_proj", "tproj.1"),
        ("txt_in.linear_1", "txtmlp.1"),
        ("txt_in.linear_2", "txtmlp.3"),
        ("text_fusion.projector", "txtfusion.projector"),
        ("final_layer.linear", "last.linear"),
    )

    # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
    layerwise_count = 2
    refiner_count = 2

    block_prefixes = []
    for idx in range(mmdit_config.get("layers", 0)):
        block_prefixes.append(("transformer_blocks.{}".format(idx), "blocks.{}".format(idx)))
    for idx in range(layerwise_count):
        block_prefixes.append(("text_fusion.layerwise_blocks.{}".format(idx), "txtfusion.layerwise_blocks.{}".format(idx)))
    for idx in range(refiner_count):
        block_prefixes.append(("text_fusion.refiner_blocks.{}".format(idx), "txtfusion.refiner_blocks.{}".format(idx)))

    key_map = {}
    for prefix_to, prefix_from in block_prefixes:
        for src, dst in block_suffixes:
            key_map["{}.{}.weight".format(prefix_to, src)] = "{}{}.{}.weight".format(output_prefix, prefix_from, dst)
    for src, dst in top_level:
        key_map["{}.weight".format(src)] = "{}{}.weight".format(output_prefix, dst)
    return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)

View File

@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"default": False,
"description": "Show the sign-in button in the frontend even when not signed in",
},
"enable_telemetry": {
"type": "bool",
"default": False,
"description": "Signal the frontend that telemetry collection is enabled",
},
}

View File

@ -27,10 +27,13 @@ class VideoInput(ABC):
path: Union[str, IO[bytes]],
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None
metadata: Optional[dict] = None,
bit_depth: int | None = None,
):
"""
Abstract method to save the video input to a file.
bit_depth selects the encoded bit depth; None keeps the video's native depth.
"""
pass
@ -83,6 +86,14 @@ class VideoInput(ABC):
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_bit_depth(self) -> int:
"""
Returns the bit depth of the video (e.g. 8 or 10).
Default implementation returns 8; subclasses report their real depth.
Subclasses that do not override this method therefore report 8.
"""
return 8
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.

View File

@ -52,6 +52,12 @@ def get_open_write_kwargs(
return open_kwargs
def video_stream_bit_depth(stream) -> int:
    """Return the largest per-component bit count of a video stream's pixel format.

    Falls back to 8 when ``stream`` is None, has no format, or the format lists
    no components.
    """
    fallback_depth = 8
    if stream is None:
        return fallback_depth
    pixel_format = stream.format
    if pixel_format is None:
        return fallback_depth
    components = pixel_format.components
    if not components:
        return fallback_depth
    return max(part.bits for part in components)
class VideoFromFile(VideoInput):
"""
Class representing video input from a file.
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_bit_depth(self) -> int:
    """Open the source with PyAV and return the bit depth of its first video stream.

    Returns whatever video_stream_bit_depth gives for ``None`` when the container
    has no video stream.
    """
    source = self.__file
    if isinstance(source, io.BytesIO):
        source.seek(0)  # Reset the BytesIO object to the beginning
    with av.open(source, mode="r") as container:
        video_streams = container.streams.video
        first_stream = video_streams[0] if len(video_streams) > 0 else None
        return video_stream_bit_depth(first_stream)
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
image_format = 'gbrpf32le'
process_image_format = lambda a: a
align_graph = None
audio = None
streams = [video_stream]
@ -310,7 +324,28 @@ class VideoFromFile(VideoInput):
checked_alpha = True
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
# Fix non-deterministic video decode when the video width is not a multiple of 32
# For non-yuvj pixel formats: most H.264/H.265 video and static images (e.g. lossy WebP via LoadImage)
# Pad both axes to a multiple of 32 and smear the border so the alignment padding never bleeds into the cropped edges
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
if align_graph is None:
pad_w = ((frame.width + 31) // 32) * 32
pad_h = ((frame.height + 31) // 32) * 32
g = av.filter.Graph()
g_src = g.add_buffer(width=frame.width, height=frame.height,
format=frame.format.name, time_base=video_stream.time_base)
g_pad = g.add('pad', f'{pad_w}:{pad_h}:0:0')
g_fill = g.add('fillborders', f'left=0:right={pad_w - frame.width}:top=0:bottom={pad_h - frame.height}:mode=smear')
g_sink = g.add('buffersink')
g_src.link_to(g_pad)
g_pad.link_to(g_fill)
g_fill.link_to(g_sink)
g.configure()
align_graph = (g, g_src, g_sink)
align_graph[1].push(frame)
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:frame.height, :frame.width])
else:
img = frame.to_ndarray(format=image_format)
if frame.rotation != 0:
k = int(round(frame.rotation // 90))
img = np.rot90(img, k=k, axes=(0, 1)).copy()
@ -377,25 +412,32 @@ class VideoFromFile(VideoInput):
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
bit_depth: int | None = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
with av.open(self.__file, mode='r') as container:
container_format = container.format.name
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
video_encoding = video_stream.codec.name if video_stream is not None else None
source_bit_depth = video_stream_bit_depth(video_stream)
reuse_streams = True
if format != VideoContainer.AUTO and format not in container_format.split(","):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
reuse_streams = False
if self.__start_time or self.__duration:
reuse_streams = False
if not reuse_streams:
if bit_depth is None:
bit_depth = source_bit_depth
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
path, format=format, codec=codec, metadata=metadata
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
)
streams = container.streams
@ -451,8 +493,10 @@ class VideoFromComponents(VideoInput):
Class representing video input from tensors.
"""
def __init__(self, components: VideoComponents):
def __init__(self, components: VideoComponents, bit_depth: int = 8):
self.__components = components
# Tensor components have no inherent bit depth; this is the depth used when encoding.
self.__bit_depth = bit_depth
def get_components(self) -> VideoComponents:
return VideoComponents(
@ -461,18 +505,26 @@ class VideoFromComponents(VideoInput):
frame_rate=self.__components.frame_rate,
)
def get_bit_depth(self) -> int:
return self.__bit_depth
def save_to(
self,
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
bit_depth: int | None = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
# None means "use the depth this video was created with" (CreateVideo's choice).
if bit_depth is None:
bit_depth = self.__bit_depth
is_10bit = bit_depth >= 10
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
@ -488,10 +540,11 @@ class VideoFromComponents(VideoInput):
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
# Create a video stream
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
video_stream = output.add_stream('h264', rate=frame_rate)
video_stream.width = self.__components.images.shape[2]
video_stream.height = self.__components.images.shape[1]
video_stream.pix_fmt = 'yuv420p'
video_stream.pix_fmt = pix_fmt
# Create an audio stream
audio_sample_rate = 1
@ -505,9 +558,14 @@ class VideoFromComponents(VideoInput):
# Encode video
for i, frame in enumerate(self.__components.images):
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
if is_10bit:
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
else:
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
frame = frame.reformat(format=pix_fmt)
packet = video_stream.encode(frame)
output.mux(packet)

View File

@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO):
Type = File3D
@comfytype(io_type="FILE_3D_SPLAT_ANY")
class File3DSplatAny(ComfyTypeIO):
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
# Same Python-side value type (File3D) as the format-specific File3D* IO types.
Type = File3D
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
class File3DPointCloudAny(ComfyTypeIO):
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
# Same Python-side value type (File3D) as the format-specific File3D* IO types.
Type = File3D
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@ -1388,7 +1400,8 @@ class V3Data(TypedDict):
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
extra_pnginfo: Any, dynprompt: Any,
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
auth_token_comfy_org: str, api_key_comfy_org: str,
comfy_usage_source: str = None, **kwargs):
self.unique_id = unique_id
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
self.prompt = prompt
@ -1401,6 +1414,8 @@ class HiddenHolder:
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
self.api_key_comfy_org = api_key_comfy_org
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
self.comfy_usage_source = comfy_usage_source
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
def __getattr__(self, key: str):
'''If hidden variable not found, return None.'''
@ -1417,6 +1432,7 @@ class HiddenHolder:
dynprompt=d.get(Hidden.dynprompt, None),
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
)
@classmethod
@ -1439,6 +1455,8 @@ class Hidden(str, Enum):
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
api_key_comfy_org = "API_KEY_COMFY_ORG"
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
comfy_usage_source = "COMFY_USAGE_SOURCE"
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
@dataclass
@ -1642,6 +1660,8 @@ class Schema:
self.hidden.append(Hidden.auth_token_comfy_org)
if Hidden.api_key_comfy_org not in self.hidden:
self.hidden.append(Hidden.api_key_comfy_org)
if Hidden.comfy_usage_source not in self.hidden:
self.hidden.append(Hidden.comfy_usage_source)
# if is an output_node, will need prompt and extra_pnginfo
if self.is_output_node:
if Hidden.prompt not in self.hidden:
@ -2336,6 +2356,8 @@ __all__ = [
"File3DSPLAT",
"File3DSPZ",
"File3DKSPLAT",
"File3DSplatAny",
"File3DPointCloudAny",
"Hooks",
"HookKeyframes",
"TimestepsRange",

View File

@ -285,7 +285,7 @@ class AudioSaveHelper:
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
file = f"{filename_with_batch_num}_{counter:05}.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially

View File

@ -1310,13 +1310,6 @@ class KlingTaskStatus(str, Enum):
failed = 'failed'
class KlingTextToVideoModelName(str, Enum):
kling_v1 = 'kling-v1'
kling_v1_6 = 'kling-v1-6'
kling_v2_1_master = 'kling-v2-1-master'
kling_v2_5_turbo = 'kling-v2-5-turbo'
class KlingVideoGenAspectRatio(str, Enum):
field_16_9 = '16:9'
field_9_16 = '9:16'
@ -5179,7 +5172,7 @@ class KlingText2VideoRequest(BaseModel):
duration: Optional[KlingVideoGenDuration] = '5'
external_task_id: Optional[str] = Field(None, description='Customized Task ID')
mode: Optional[KlingVideoGenMode] = 'std'
model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
model_name: Optional[str] = 'kling-v1'
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=2500
)

View File

@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel):
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
dilate_pixels: int = Field(10)
seed: int | None = Field(None)
output_format: str = Field("png")

View File

@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)
class BriaVideoGreenScreenRequest(BaseModel):
# Request body for a Bria video green-screen job. Everything except green_shade
# and preserve_audio is required.
video: str = Field(..., description="Publicly accessible URL of the input video.")
green_shade: str = Field(
default="broadcast_green",
description="Solid chroma-key shade applied behind the foreground "
"(broadcast_green, chroma_green, or blue_screen).",
)
# Required, but not validated here beyond being a string.
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaVideoReplaceBackgroundRequest(BaseModel):
# Request body for a Bria video background-replacement job. Everything except
# preserve_audio is required.
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
background_url: str = Field(
...,
description="Publicly accessible URL of the background image or video to composite behind "
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
"undistorted results.",
)
# Required, but not validated here beyond being a string.
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)

View File

@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
startOffset: GeminiOffset | None = Field(None)
class GeminiThinkingConfig(BaseModel):
# Thinking options; referenced by GeminiGenerationConfig.thinkingConfig, which is
# why this model is declared ahead of that class.
includeThoughts: bool | None = Field(None)
# Required whenever a thinking config is supplied.
thinkingLevel: str = Field(...)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(None, ge=0.0, le=2.0)
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
class GeminiImageOutputOptions(BaseModel):
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)

View File

@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel):
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)
class Kling3TurboSettings(BaseModel):
# Output settings shared by the Kling 3 Turbo text-to-video and image-to-video
# requests. The allowed values appear only in the field descriptions; no
# validators enforce them here.
resolution: str = Field("720p", description="'720p' or '1080p'")
aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only")
duration: int = Field(5, description="3-15 second")
# Text-to-video request body: a required prompt plus an optional settings block.
class Kling3TurboText2VideoRequest(BaseModel):
    prompt: str = Field(
        ...,
        description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'",
    )
    settings: Kling3TurboSettings | None = None
# One content item of an image-to-video request; `type` selects which of `text` / `url` applies.
class Kling3TurboContent(BaseModel):
    type: str = Field(
        ...,
        description="'prompt' or 'first_frame'",
    )
    text: str | None = Field(
        default=None,
        description="for type=prompt; <=2500 chars",
    )
    url: str | None = Field(
        default=None,
        description="for type=first_frame",
    )
# Image-to-video request body: a required list of content items plus an optional settings block.
class Kling3TurboImage2VideoRequest(BaseModel):
    contents: list[Kling3TurboContent] = Field(
        ...,
        description="prompt + first_frame materials",
    )
    settings: Kling3TurboSettings | None = None
# `data` payload of a Kling 3 Turbo create response; every field is optional.
class Kling3TurboCreateData(BaseModel):
    id: str | None = Field(default=None, description="Task ID")
    status: str | None = None
    message: str | None = None
# Envelope of a Kling 3 Turbo create call; every field is optional and `data` holds the created task.
class Kling3TurboCreateResponse(BaseModel):
    code: int | None = None
    message: str | None = None
    request_id: str | None = None
    data: Kling3TurboCreateData | None = None
# One output entry listed on a Kling 3 Turbo task; every field is optional.
class Kling3TurboOutput(BaseModel):
    type: str | None = Field(
        default=None,
        description="'video', 'image', 'audio', ...",
    )
    id: str | None = None
    url: str | None = None
    # Declared as a string, not a number.
    duration: str | None = None
# One task record inside a Kling 3 Turbo query response; every field is optional.
class Kling3TurboTaskData(BaseModel):
    id: str | None = None
    status: str | None = Field(
        default=None,
        description="submitted | processing | succeeded | failed",
    )
    message: str | None = None
    outputs: list[Kling3TurboOutput] | None = None
# Envelope of a Kling 3 Turbo task query; unlike the create response, `data` is a list of task records.
class Kling3TurboQueryResponse(BaseModel):
    code: int | None = None
    message: str | None = None
    request_id: str | None = None
    data: list[Kling3TurboTaskData] | None = None

View File

@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat
class LumaIO:
    # Namespace of string identifiers used by the Luma nodes.
    # NOTE(review): presumably these are custom socket type names — confirm against the node definitions.
    LUMA_REF = "LUMA_REF"
    LUMA_CONCEPTS = "LUMA_CONCEPTS"
    LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME"
class LumaReference:
@ -20,13 +21,14 @@ class LumaReference:
def create_api_model(self, download_url: str):
return LumaImageRef(url=download_url, weight=self.weight)
class LumaReferenceChain:
def __init__(self, first_ref: LumaReference=None):
def __init__(self, first_ref: LumaReference = None):
self.refs: list[LumaReference] = []
if first_ref:
self.refs.append(first_ref)
def add(self, luma_ref: LumaReference=None):
def add(self, luma_ref: LumaReference = None):
self.refs.append(luma_ref)
def create_api_model(self, download_urls: list[str], max_refs=4):
@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False):
"pull_out",
"aerial",
"crane_up",
"eye_level"
"eye_level",
]
@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum):
class LumaGenerationType(str, Enum):
video = 'video'
image = 'image'
video = "video"
image = "image"
class LumaState(str, Enum):
@ -174,86 +176,109 @@ class LumaState(str, Enum):
class LumaAssets(BaseModel):
video: Optional[str] = Field(None, description='The URL of the video')
image: Optional[str] = Field(None, description='The URL of the image')
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
video: Optional[str] = Field(None, description="The URL of the video")
image: Optional[str] = Field(None, description="The URL of the image")
progress_video: Optional[str] = Field(None, description="The URL of the progress video")
class LumaImageRef(BaseModel):
"""Used for image gen"""
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaImageReference(BaseModel):
"""Used for video gen"""
type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image')
type: Optional[str] = Field("image", description="Input type, defaults to image")
url: str = Field(..., description="The URL of the image")
class LumaModifyImageRef(BaseModel):
url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
url: str = Field(..., description="The URL of the image reference")
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
class LumaCharacterRef(BaseModel):
identity0: LumaImageIdentity = Field(..., description='The image identity object')
identity0: LumaImageIdentity = Field(..., description="The image identity object")
class LumaImageIdentity(BaseModel):
images: list[str] = Field(..., description='The URLs of the image identity')
images: list[str] = Field(..., description="The URLs of the image identity")
class LumaGenerationReference(BaseModel):
type: str = Field('generation', description='Input type, defaults to generation')
id: str = Field(..., description='The ID of the generation')
type: str = Field("generation", description="Input type, defaults to generation")
id: str = Field(..., description="The ID of the generation")
class LumaKeyframes(BaseModel):
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
class LumaConceptObject(BaseModel):
key: str = Field(..., description='Camera Concept name')
key: str = Field(..., description="Camera Concept name")
class LumaImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
prompt: str = Field(..., description="The prompt of the generation")
model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9)
image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects")
style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects")
character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object")
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object")
class LumaGenerationRequest(BaseModel):
prompt: str = Field(..., description='The prompt of the generation')
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
loop: Optional[bool] = Field(None, description='Whether to loop the video')
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
prompt: str = Field(..., description="The prompt of the generation")
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation")
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation")
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation")
resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation")
loop: Optional[bool] = Field(None, description="Whether to loop the video")
keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation")
concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation")
class LumaGeneration(BaseModel):
id: str = Field(..., description='The ID of the generation')
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
state: LumaState = Field(..., description='The state of the generation')
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
created_at: str = Field(..., description='The date and time when the generation was created')
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
model: str = Field(..., description='The model used for the generation')
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
id: str = Field(..., description="The ID of the generation")
generation_type: LumaGenerationType = Field(..., description="Generation type, image or video")
state: LumaState = Field(..., description="The state of the generation")
failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation")
created_at: str = Field(..., description="The date and time when the generation was created")
assets: Optional[LumaAssets] = Field(None, description="The assets of the generation")
model: str = Field(..., description="The model used for the generation")
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...)
# Reference to an image (or a prior generation) for Luma v2 requests; every field is optional.
class Luma2ImageRef(BaseModel):
    url: str | None = Field(default=None)
    data: str | None = Field(default=None)
    media_type: str | None = Field(default=None)
    generation_id: str | None = Field(
        default=None,
        description="reference a prior generation (extend / source reuse)",
    )
class Luma2VideoEdit(BaseModel):
    """Edit controls for Ray 3.2 ``video_edit`` generations."""

    # Both controls are optional and default to None.
    auto_controls: bool | None = Field(
        default=None,
        description="derive a conditioning schedule from the source (recommended)",
    )
    strength: str | None = Field(
        default=None,
        description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo",
    )
class Luma2VideoOptions(BaseModel):
    """Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend)."""

    # Output shape; both are plain strings here.
    resolution: str | None = Field(default=None, description="360p | 540p | 720p | 1080p")
    duration: str | None = Field(default=None, description="5s | 10s")
    loop: bool | None = None

    # Frame guidance: single start/end frames, or a keyframe list with a separate list of indexes.
    start_frame: Luma2ImageRef | None = None
    end_frame: Luma2ImageRef | None = None
    keyframes: list[Luma2ImageRef] | None = None
    keyframe_indexes: list[int] | None = None

    # Edit controls.
    edit: Luma2VideoEdit | None = None
class Luma2GenerationRequest(BaseModel):
@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel):
web_search: bool | None = None
image_ref: list[Luma2ImageRef] | None = None
source: Luma2ImageRef | None = None
video: Luma2VideoOptions | None = Field(None)
class Luma2Generation(BaseModel):
@ -277,3 +303,31 @@ class Luma2Generation(BaseModel):
output: list[LumaImageReference] | None = None
failure_reason: str | None = None
failure_code: str | None = None
# --- Ray 3.2 multi-keyframe chain ---
LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration
LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output
class LumaRay32KeyframeItem:
    """A single guide image pinned to a position on the Ray 3.2 output timeline."""

    def __init__(self, image: torch.Tensor, mode: str, value: float):
        # `mode` is LUMA_KEYFRAME_MODE_FRACTION or LUMA_KEYFRAME_MODE_SECONDS and tells
        # how `value` is to be read.
        self.image, self.mode, self.value = image, mode, value
class LumaRay32KeyframeChain:
    """Ordered list of Ray 3.2 keyframe items, kept in insertion order."""

    def __init__(self):
        self.items: list["LumaRay32KeyframeItem"] = []

    def add(self, item: "LumaRay32KeyframeItem") -> None:
        """Append `item` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "LumaRay32KeyframeChain":
        """Return a new chain with its own list holding the same item objects (shallow copy)."""
        duplicate = LumaRay32KeyframeChain()
        duplicate.items = self.items.copy()
        return duplicate

View File

@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
id: Optional[str] = Field(None, description='Task ID')
class RunwayTaskStatusEnum(str, Enum):
SUCCEEDED = 'SUCCEEDED'
RUNNING = 'RUNNING'
FAILED = 'FAILED'
PENDING = 'PENDING'
CANCELLED = 'CANCELLED'
THROTTLED = 'THROTTLED'
class RunwayTaskStatusResponse(BaseModel):
createdAt: datetime = Field(..., description='Task creation timestamp')
id: str = Field(..., description='Task ID')
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
ge=0.0,
le=1.0,
)
status: RunwayTaskStatusEnum
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
class Model4(str, Enum):
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
class RunwayTextToImageResponse(BaseModel):
id: Optional[str] = Field(None, description='Task ID')
class RunwayAleph2IO:
    """Custom socket types for chaining Aleph2 guidance images."""

    # NOTE(review): presumably the socket type carrying a keyframe chain — confirm in the node definitions.
    KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
    # NOTE(review): presumably the socket type carrying a prompt-image chain — confirm in the node definitions.
    PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
# choose the request model below. The values match the Aleph2 keyframe union field names.
KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video
KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration
# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video
PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration
class RunwayAleph2KeyframeItem:
    """One Aleph2 ``keyframe``: a guidance image tied to a point of the INPUT video."""

    def __init__(self, image, mode: str, value: float):
        # `mode` is KEYFRAME_MODE_SECONDS or KEYFRAME_MODE_AT and tells how `value` is to be read.
        self.image, self.mode, self.value = image, mode, value
class RunwayAleph2KeyframeChain:
    """Keyframes in insertion order, accumulated by chaining Runway Aleph2 Keyframe nodes."""

    def __init__(self):
        self.items: list["RunwayAleph2KeyframeItem"] = []

    def add(self, item: "RunwayAleph2KeyframeItem") -> None:
        """Append `item` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2KeyframeChain":
        """Return a new chain with its own list holding the same item objects (shallow copy)."""
        duplicate = RunwayAleph2KeyframeChain()
        duplicate.items = self.items.copy()
        return duplicate
class RunwayAleph2PromptImageItem:
    """One Aleph2 ``promptImage``: a guidance image tied to a point of the OUTPUT video."""

    def __init__(self, image, mode: str, value: float):
        # `mode` is PROMPT_IMAGE_MODE_TIMESTAMP or PROMPT_IMAGE_MODE_POSITION and tells
        # how `value` is to be read.
        self.image, self.mode, self.value = image, mode, value
class RunwayAleph2PromptImageChain:
    """Prompt images in insertion order, accumulated by chaining Runway Aleph2 Prompt Image nodes."""

    def __init__(self):
        self.items: list["RunwayAleph2PromptImageItem"] = []

    def add(self, item: "RunwayAleph2PromptImageItem") -> None:
        """Append `item` to the end of the chain."""
        self.items.append(item)

    def clone(self) -> "RunwayAleph2PromptImageChain":
        """Return a new chain with its own list holding the same item objects (shallow copy)."""
        duplicate = RunwayAleph2PromptImageChain()
        duplicate.items = self.items.copy()
        return duplicate
# Keyframe variant positioned by absolute time; `seconds` must be >= 0 and `uri` is required.
class RunwayAleph2KeyframeSeconds(BaseModel):
    seconds: float = Field(
        ...,
        ge=0.0,
        description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
    )
    uri: str
# Keyframe variant positioned by fraction; `at` must lie in [0.0, 1.0] and `uri` is required.
class RunwayAleph2KeyframeAt(BaseModel):
    at: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Position as a fraction [0.0, 1.0] of the input video duration.",
    )
    uri: str
# Prompt-image position given as absolute seconds; `type` defaults to "timestamp".
class RunwayAleph2TimestampPosition(BaseModel):
    type: str = "timestamp"
    timestampSeconds: float = Field(
        ...,
        ge=0.0,
        description="Absolute timestamp in seconds from the start of the output video.",
    )
# Prompt-image position given as a fraction in [0.0, 1.0]; `type` defaults to "position".
class RunwayAleph2RelativePosition(BaseModel):
    type: str = "position"
    positionPercentage: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Position as a fraction [0.0, 1.0] of the total output video duration.",
    )
# One prompt image: a required position (timestamp or relative form) and a required image URI.
class RunwayAleph2PromptImage(BaseModel):
    position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
    uri: str
# Content-moderation settings of an Aleph2 request; the single field is required.
class RunwayAleph2ContentModeration(BaseModel):
    publicFigureThreshold: str = Field(
        ...,
        description=(
            'When set to "low", the content moderation system is less strict about '
            'recognizable public figures. One of "auto" or "low".'
        ),
    )
# Request body for an Aleph2 task. `model` defaults to "aleph2"; the prompt, video URI, seed and
# moderation settings are required; the two guidance-image lists are optional.
class RunwayAleph2Request(BaseModel):
    model: str = "aleph2"
    promptText: str = Field(
        ...,
        min_length=1,
        max_length=1000,
        description="A non-empty string describing what should appear in the output.",
    )
    videoUri: str
    # Seed is constrained to the unsigned 32-bit range.
    seed: int = Field(
        ...,
        ge=0,
        le=4294967295,
        description="Random seed for generation",
    )
    contentModeration: RunwayAleph2ContentModeration
    keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
        default=None,
        description="Timed guidance images placed at specific points in the input video. Up to 5.",
    )
    promptImage: list[RunwayAleph2PromptImage] | None = Field(
        default=None,
        description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
    )
# Response to an Aleph2 request; carries only an optional task id.
class RunwayAleph2Response(BaseModel):
    id: str | None = Field(
        default=None,
        description="Task ID",
    )

View File

@ -208,6 +208,10 @@ class TripoMultiviewToModelRequest(BaseModel):
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
# Texture guidance object; its single `text` field is optional.
class TripoTexturePrompt(BaseModel):
    text: str | None = Field(
        default=None,
        description="Text guidance for texture generation",
    )
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
@ -219,6 +223,11 @@ class TripoTextureModelRequest(BaseModel):
texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
)
texture_prompt: TripoTexturePrompt | None = Field(
None,
description="Optional guidance for texturing. Required in practice for imported models, "
"which carry no source image to infer texture from.",
)
class TripoRefineModelRequest(BaseModel):
@ -307,6 +316,17 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
orientation: str | None = None
class TripoImportModelRequest(BaseModel):
    """Request for the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).

    The model file is uploaded to ComfyUI API storage first; the backend downloads it from
    `url`, re-uploads it to Tripo's storage and creates the import_model task server-side.
    """

    # Both fields are required.
    url: str = Field(
        ...,
        description="ComfyUI API storage download URL of the model file",
    )
    format: str = Field(
        ...,
        description='File format: "glb", "fbx", "obj" or "stl"',
    )
class TripoTaskOutput(BaseModel):
model: str | None = Field(None, description="URL to the model")
base_model: str | None = Field(None, description="URL to the base model")

View File

@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode):
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[IO.Image.Output()],
hidden=[
@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode):
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
seed: int = 0,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode):
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
seed=seed,
),
)

View File

@ -12,6 +12,8 @@ from comfy_api_nodes.apis.bria import (
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaStatusResponse,
BriaVideoGreenScreenRequest,
BriaVideoReplaceBackgroundRequest,
InputModerationSettings,
)
from comfy_api_nodes.util import (
@ -287,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
),
)
@ -319,6 +321,161 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
# API node that asks Bria to put a solid chroma-key colour behind a video's foreground.
class BriaVideoGreenScreen(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        """Build the node schema: video, shade choice and seed inputs; a single video output."""
        video_input = IO.Video.Input("video")
        shade_input = IO.Combo.Input(
            "green_shade",
            options=["broadcast_green", "chroma_green", "blue_screen"],
            tooltip="Solid chroma-key shade applied behind the foreground: "
            "broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
        )
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            display_mode=IO.NumberDisplay.number,
            control_after_generate=True,
            tooltip="Seed controls whether the node should re-run; "
            "results are non-deterministic regardless of seed.",
        )
        return IO.Schema(
            node_id="BriaVideoGreenScreen",
            display_name="Bria Video Green Screen",
            category="partner/video/Bria",
            description="Replace a video's background with a solid chroma-key screen using Bria.",
            inputs=[video_input, shade_input, seed_input],
            outputs=[IO.Video.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        green_shade: str,
        seed: int,
    ) -> IO.NodeOutput:
        """Upload the video, submit the green-screen job, poll its status and return the result video.

        The video is first checked with ``validate_video_duration(..., max_duration=60.0)``.
        The output is always requested as ``mp4_h264``.
        """
        validate_video_duration(video, max_duration=60.0)

        # Step 1: upload the source and submit the edit request.
        payload = BriaVideoGreenScreenRequest(
            video=await upload_video_to_comfyapi(cls, video),
            green_shade=green_shade,
            output_container_and_codec="mp4_h264",
            seed=seed,
        )
        submitted = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
            data=payload,
            response_model=BriaStatusResponse,
        )

        # Step 2: poll the status endpoint using the request id returned by the submission.
        status_endpoint = ApiEndpoint(path=f"/proxy/bria/v2/status/{submitted.request_id}")
        finished = await poll_op(
            cls,
            status_endpoint,
            status_extractor=lambda r: r.status,
            response_model=BriaRemoveVideoBackgroundResponse,
        )

        # Step 3: download the produced video and wrap it as the node output.
        result_video = await download_url_to_video_output(finished.result.video_url)
        return IO.NodeOutput(result_video)
# API node that asks Bria to composite a supplied image or video behind a video's foreground.
class BriaVideoReplaceBackground(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        """Build the node schema: foreground video, optional background image/video, seed; one video out."""
        return IO.Schema(
            node_id="BriaVideoReplaceBackground",
            display_name="Bria Video Replace Background",
            category="partner/video/Bria",
            description="Replace a video's background with a supplied image or video using Bria. "
            "The output keeps the foreground's resolution and frame rate; a background with a "
            "different aspect ratio is stretched to fit, so match it for undistorted results.",
            inputs=[
                IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
                IO.Image.Input(
                    "background_image",
                    optional=True,
                    tooltip="Background image to composite behind the foreground. "
                    "Provide either a background image or a background video, not both.",
                ),
                IO.Video.Input(
                    "background_video",
                    optional=True,
                    tooltip="Background video to composite behind the foreground. "
                    "Provide either a background image or a background video, not both.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="Seed controls whether the node should re-run; "
                    "results are non-deterministic regardless of seed.",
                ),
            ],
            outputs=[IO.Video.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        seed: int,
        background_image: Input.Image | None = None,
        background_video: Input.Video | None = None,
    ) -> IO.NodeOutput:
        """Upload foreground and background, submit the replace-background job and return the result video.

        Exactly one of ``background_image`` / ``background_video`` must be supplied.

        Raises:
            ValueError: if both backgrounds are supplied, or if neither is.
        """
        # Report the two invalid combinations separately: with no background connected,
        # a "not both" message would describe the opposite mistake.
        if background_image is not None and background_video is not None:
            raise ValueError("Provide either a background image or a background video, not both.")
        if background_image is None and background_video is None:
            raise ValueError("Provide a background image or a background video.")

        validate_video_duration(video, max_duration=60.0)
        if background_video is not None:
            validate_video_duration(background_video, max_duration=60.0)
            background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
        else:
            # Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
            background_url = await upload_image_to_comfyapi(
                cls, background_image[:, :, :, :3], wait_label="Uploading background"
            )

        # The background is uploaded first; the foreground upload happens while building the request.
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
            data=BriaVideoReplaceBackgroundRequest(
                video=await upload_video_to_comfyapi(cls, video),
                background_url=background_url,
                output_container_and_codec="mp4_h264",
                seed=seed,
            ),
            response_model=BriaStatusResponse,
        )
        # Poll the status endpoint using the request id returned by the submission.
        response = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
            status_extractor=lambda r: r.status,
            response_model=BriaRemoveVideoBackgroundResponse,
        )
        return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
@ -376,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
),
)
@ -416,6 +573,8 @@ class BriaExtension(ComfyExtension):
BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
BriaVideoGreenScreen,
BriaVideoReplaceBackground,
BriaTransparentVideoBackground,
]

View File

@ -7,6 +7,7 @@ from io import BytesIO
import torch
from typing_extensions import override
from comfy.utils import common_upscale
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image:
return image
# Output aspect ratios Seedance 2 supports, as (width, height) integer pairs. First/last frames
# are pre-sized to a matching pixel pair so the 1080p stretch jump is avoided.
SEEDANCE2_RATIO_WH = {
    "16:9": (16, 9),
    "4:3": (4, 3),
    "1:1": (1, 1),
    "3:4": (3, 4),
    "9:16": (9, 16),
    "21:9": (21, 9),
}
# Pixel length of the shorter output side for each resolution label.
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}


def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
    """Return the (width, height) to pre-size a frame to for the given resolution and ratio.

    The shorter side is the resolution number (1080p at 16:9 gives 1920x1080). When `ratio`
    is not a key of SEEDANCE2_RATIO_WH (e.g. "adaptive"), the supported ratio closest to the
    image's own width/height is used instead, so the frame keeps its orientation. Both
    dimensions are rounded down to an even number.

    Raises:
        KeyError: if `resolution` is not a key of SEEDANCE2_RES_SHORT_SIDE.
    """
    short_side = SEEDANCE2_RES_SHORT_SIDE[resolution]

    if ratio not in SEEDANCE2_RATIO_WH:
        # Image tensor is laid out (B, H, W, C), so shape[-2] is width and shape[-3] is height.
        image_aspect = image.shape[-2] / image.shape[-3]

        def _distance(name: str) -> float:
            cand_w, cand_h = SEEDANCE2_RATIO_WH[name]
            return abs(cand_w / cand_h - image_aspect)

        ratio = min(SEEDANCE2_RATIO_WH, key=_distance)

    ratio_w, ratio_h = SEEDANCE2_RATIO_WH[ratio]
    if ratio_w < ratio_h:
        # Portrait: the width is the shorter side.
        width, height = short_side, round(short_side * ratio_h / ratio_w)
    else:
        # Landscape or square: the height is the shorter side.
        width, height = round(short_side * ratio_w / ratio_h), short_side
    return width // 2 * 2, height // 2 * 2
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Return `image` resized to exactly width x height using lanczos with center cropping.

    The last axis is moved to position 1 for `common_upscale` (channels-last to
    channels-first) and moved back afterwards.
    """
    channels_first = image.movedim(-1, 1)
    return common_upscale(channels_first, width, height, "lanczos", "center").movedim(1, -1)
async def _resolve_reference_assets(
cls: type[IO.ComfyNode],
asset_ids: list[str],
@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
if last_frame is not None and last_frame_asset_id:
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
request_ratio = model["ratio"]
if first_frame_asset_id or last_frame_asset_id:
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
else:
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
# only applies to local image inputs we can resize.
request_ratio = "adaptive"
target_dims: tuple[int, int] | None = None
if first_frame is not None:
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(first_frame, min_width=300, min_height=300)
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
first_frame = _resize_to_exact(first_frame, *target_dims)
if last_frame is not None:
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(last_frame, min_width=300, min_height=300)
if target_dims is None:
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
last_frame = _resize_to_exact(last_frame, *target_dims)
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
image_assets: dict[str, str] = {}
@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
content=content,
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
ratio=request_ratio,
duration=model["duration"],
seed=seed,
watermark=watermark,

View File

@ -5,10 +5,9 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64
import os
from enum import Enum
from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
from typing import Any, Literal
import torch
from typing_extensions import override
@ -19,6 +18,7 @@ from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest,
GeminiGenerationConfig,
GeminiGenerateContentResponse,
GeminiImageConfig,
GeminiImageGenerateContentRequest,
@ -40,13 +40,18 @@ from comfy_api_nodes.util import (
get_number_of_images,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
video_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_URL_INPUT_BUDGET = 10
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
@ -72,15 +77,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
)
class GeminiImageModel(str, Enum):
"""
Gemini Image Model Names allowed by comfy-api
"""
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
gemini_2_5_flash_image = "gemini-2.5-flash-image"
async def create_image_parts(
cls: type[IO.ComfyNode],
images: Input.Image | list[Input.Image],
@ -237,21 +233,15 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
if not response.modelVersion:
return None
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
if response.modelVersion == "gemini-2.5-pro":
input_tokens_price = 1.25
output_text_tokens_price = 10.0
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-flash",
):
elif response.modelVersion == "gemini-2.5-flash":
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 0.0
elif response.modelVersion in (
"gemini-2.5-flash-image-preview",
"gemini-2.5-flash-image",
):
elif response.modelVersion == "gemini-2.5-flash-image":
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 30.0
@ -285,6 +275,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
return final_price / 1_000_000.0
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
    """Encode one video as an MP4/H.264 base64 string and wrap it in a single inline Gemini part.

    Returns a one-element list so callers can `extend` their part list with it.
    """
    encoded_video = video_to_base64_string(
        video_input,
        container_format=Types.VideoContainer.MP4,
        codec=Types.VideoCodec.H264,
    )
    inline_video = GeminiInlineData(mimeType=GeminiMimeType.video_mp4, data=encoded_video)
    return [GeminiPart(inlineData=inline_video)]
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
    """Build one inline MP3 Gemini part for each item along the first dimension of the waveform."""

    def _encode_clip(clip_index: int) -> GeminiPart:
        # Rebuild a standalone audio input holding only this batch item (leading dim restored via unsqueeze).
        single_clip = Input.Audio(
            waveform=audio_input["waveform"][clip_index].unsqueeze(0),
            sample_rate=audio_input["sample_rate"],
        )
        # MP3 (libmp3lame) is the encoding sent to the Gemini API for audio.
        encoded_clip = audio_to_base64_string(
            single_clip,
            container_format="mp3",
            codec_name="libmp3lame",
        )
        return GeminiPart(
            inlineData=GeminiInlineData(mimeType=GeminiMimeType.audio_mp3, data=encoded_clip)
        )

    clip_count = audio_input["waveform"].shape[0]
    return [_encode_clip(clip_index) for clip_index in range(clip_count)]
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
frames: list[torch.Tensor] = []
for img in images:
if len(img.shape) == 4:
frames.extend(img[i] for i in range(img.shape[0]))
else:
frames.append(img)
return frames
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
    """Split each audio input along dim 0 of its waveform into single-clip inputs, keeping input order.

    Every produced clip keeps the sample rate of the input it came from.
    """
    return [
        Input.Audio(waveform=batch["waveform"][idx].unsqueeze(0), sample_rate=batch["sample_rate"])
        for batch in audios
        for idx in range(batch["waveform"].shape[0])
    ]
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
    """Upload one media item through the ComfyAPI upload helpers and return a fileData part for its URL.

    `kind` selects the upload helper: "image" (PNG), "audio" (MP3); any other value is treated as video.
    """
    if kind == "image":
        file_url = await upload_image_to_comfyapi(
            cls, payload, mime_type="image/png", wait_label="Uploading image"
        )
        mime = GeminiMimeType.image_png
    elif kind == "audio":
        file_url = await upload_audio_to_comfyapi(
            cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
        )
        mime = GeminiMimeType.audio_mp3
    else:
        file_url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
        mime = GeminiMimeType.video_mp4
    return GeminiPart(fileData=GeminiFileData(mimeType=mime, fileUri=file_url))
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
    """Encode one media item as an inline base64 part.

    `kind` "image" is encoded as WebP and "audio" as MP3; any other value is treated as video (MP4/H.264).
    Returns the part together with the length of its base64 string.
    """
    if kind == "image":
        mime = GeminiMimeType.image_webp
        encoded = tensor_to_base64_string(payload, mime_type="image/webp")
    elif kind == "audio":
        mime = GeminiMimeType.audio_mp3
        encoded = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
    else:
        mime = GeminiMimeType.video_mp4
        encoded = video_to_base64_string(
            payload,
            container_format=Types.VideoContainer.MP4,
            codec=Types.VideoCodec.H264,
        )
    inline_part = GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=encoded))
    return inline_part, len(encoded)
async def build_gemini_media_parts(
    cls: type[IO.ComfyNode],
    images: list[Input.Image],
    audios: list[Input.Audio],
    videos: list[Input.Video],
    *,
    url_budget: int = GEMINI_URL_INPUT_BUDGET,
    max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
) -> list[GeminiPart]:
    """Build Gemini parts for multimodal inputs (images, audio, video).

    fileData URLs are preferred for every media type, which keeps the request body small
    regardless of media size. The URL budget is shared across all media and assigned in the
    order video, then audio, then images, so that if it is exhausted the inline-base64
    overflow falls on the later items. Total inline payload (summed base64 length) is capped
    by `max_inline_bytes`.

    Args:
        cls: Node class forwarded to the upload helpers.
        images: Image inputs; batched tensors are expanded into individual frames.
        audios: Audio inputs; batched waveforms are expanded into individual clips.
        videos: Video inputs.
        url_budget: Maximum number of media units sent as uploaded URLs (values below 0 act as 0).
        max_inline_bytes: Maximum summed base64 length of the units sent inline.

    Returns:
        Parts in unit order: URL parts first, then inline parts.

    Raises:
        ValueError: If the inline overflow exceeds `max_inline_bytes`. This is checked before
            any upload starts, so an oversized request fails without uploading anything.
    """
    units: list[tuple[str, Any]] = (
        [("video", v) for v in videos]
        + [("audio", a) for a in _flatten_audio(audios)]
        + [("image", f) for f in _flatten_images(images)]
    )
    url_slots = max(url_budget, 0)
    url_units, inline_units = units[:url_slots], units[url_slots:]

    # Encode and size-check the inline overflow first: exceeding the cap is a deterministic
    # input error, so it should surface before time is spent uploading the URL-budgeted items.
    inline_parts: list[GeminiPart] = []
    inline_bytes = 0
    for kind, payload in inline_units:
        part, nbytes = _media_inline_part(kind, payload)
        inline_bytes += nbytes
        if inline_bytes > max_inline_bytes:
            raise ValueError(
                f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
                f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
            )
        inline_parts.append(part)

    url_parts: list[GeminiPart] = []
    for kind, payload in url_units:
        url_parts.append(await _media_url_part(cls, kind, payload))
    return url_parts + inline_parts
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@ -315,8 +439,6 @@ class GeminiNode(IO.ComfyNode):
IO.Combo.Input(
"model",
options=[
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-3-pro-preview",
@ -407,58 +529,9 @@ class GeminiNode(IO.ComfyNode):
)
""",
),
is_deprecated=True,
)
@classmethod
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
    """Encode the video as an MP4/H.264 base64 string and return it as a single inline Gemini part."""
    encoded_video = video_to_base64_string(
        video_input,
        container_format=Types.VideoContainer.MP4,
        codec=Types.VideoCodec.H264,
    )
    inline_video = GeminiInlineData(mimeType=GeminiMimeType.video_mp4, data=encoded_video)
    return [GeminiPart(inlineData=inline_video)]
@classmethod
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
    """
    Turn an audio input into inline MP3 Gemini parts, one per batch item.

    Args:
        audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.

    Returns:
        List of GeminiPart objects containing the encoded audio.
    """

    def _encode_clip(clip_index: int) -> GeminiPart:
        # Rebuild an audio input that holds only the selected batch item.
        single_clip = Input.Audio(
            waveform=audio_input["waveform"][clip_index].unsqueeze(0),
            sample_rate=audio_input["sample_rate"],
        )
        # MP3 (libmp3lame) is the encoding sent to the Gemini API for audio.
        encoded_clip = audio_to_base64_string(
            single_clip,
            container_format="mp3",
            codec_name="libmp3lame",
        )
        return GeminiPart(
            inlineData=GeminiInlineData(mimeType=GeminiMimeType.audio_mp3, data=encoded_clip)
        )

    clip_count = audio_input["waveform"].shape[0]
    return [_encode_clip(clip_index) for clip_index in range(clip_count)]
@classmethod
async def execute(
cls,
@ -482,9 +555,9 @@ class GeminiNode(IO.ComfyNode):
if images is not None:
parts.extend(await create_image_parts(cls, images))
if audio is not None:
parts.extend(cls.create_audio_parts(audio))
parts.extend(create_audio_parts(audio))
if video is not None:
parts.extend(cls.create_video_parts(video))
parts.extend(create_video_parts(video))
if files is not None:
parts.extend(files)
@ -512,6 +585,210 @@ class GeminiNode(IO.ComfyNode):
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
# Maps the display names offered by GeminiNodeV2's "model" combo to the model ids used in the request path.
GEMINI_V2_MODELS: dict[str, str] = {
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
}
def _gemini_text_model_inputs(thinking_default: str) -> list[IO.Input]:
    """Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls).

    Args:
        thinking_default: Default selection for the "thinking_level" combo ("LOW" or "HIGH").

    Returns:
        A fresh list of schema inputs: autogrow image/audio/video slots, the optional files
        input, the thinking level combo and the advanced sampling controls.
    """
    return [
        IO.Autogrow.Input(
            "images",
            template=IO.Autogrow.TemplateNames(
                IO.Image.Input("image"),
                names=[f"image_{i}" for i in range(1, 17)],
                min=0,
            ),
            tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
        ),
        IO.Autogrow.Input(
            "audio",
            template=IO.Autogrow.TemplateNames(
                IO.Audio.Input("audio"),
                names=["audio_1"],
                min=0,
            ),
            tooltip="Optional audio clip to use as context for the model.",
        ),
        IO.Autogrow.Input(
            "video",
            template=IO.Autogrow.TemplateNames(
                IO.Video.Input("video"),
                names=["video_1"],
                min=0,
            ),
            tooltip="Optional video clip to use as context for the model.",
        ),
        IO.Custom("GEMINI_INPUT_FILES").Input(
            "files",
            optional=True,
            tooltip="Optional file(s) to use as context for the model. "
            "Accepts inputs from the Gemini Input Files node.",
        ),
        IO.Combo.Input(
            "thinking_level",
            options=["LOW", "HIGH"],
            default=thinking_default,
            tooltip="How hard the model reasons internally before answering. "
            "HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
        ),
        IO.Float.Input(
            "temperature",
            default=1.0,
            min=0.0,
            max=2.0,
            step=0.01,
            tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
            advanced=True,
        ),
        IO.Float.Input(
            "top_p",
            default=0.95,
            min=0.0,
            max=1.0,
            step=0.01,
            tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
            advanced=True,
        ),
        IO.Int.Input(
            "max_output_tokens",
            default=32768,
            min=16,
            max=65536,
            tooltip="Maximum tokens to generate, including the model's internal thinking. "
            "With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
            "responses come back empty or truncated. The model stops early when finished, so a higher "
            "cap costs nothing extra for short replies.",
            advanced=True,
        ),
    ]
class GeminiNodeV2(IO.ComfyNode):
    """Text-generation node for the Gemini models in GEMINI_V2_MODELS, with optional media/file context."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, output, hidden auth fields and price badge."""
        model_options = [
            IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
            IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
        ]
        node_inputs = [
            IO.String.Input(
                "prompt",
                multiline=True,
                default="",
                tooltip="Text input to the model. Include detailed instructions, questions, or context.",
            ),
            IO.DynamicCombo.Input(
                "model",
                options=model_options,
                tooltip="The Gemini model used to generate the response.",
            ),
            IO.Int.Input(
                "seed",
                default=42,
                min=0,
                max=2147483647,
                control_after_generate=True,
                tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
            ),
            IO.String.Input(
                "system_prompt",
                multiline=True,
                default="",
                optional=True,
                advanced=True,
                tooltip="Foundational instructions that dictate the model's behavior.",
            ),
        ]
        # The badge shows one price pair when the model widget value contains "lite", another otherwise.
        badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["model"]),
            expr="""
            (
              $m := widgets.model;
              $contains($m, "lite") ? {
                "type": "list_usd",
                "usd": [0.00025, 0.0015],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
              } : {
                "type": "list_usd",
                "usd": [0.002, 0.012],
                "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
              }
            )
            """,
        )
        return IO.Schema(
            node_id="GeminiNodeV2",
            display_name="Google Gemini",
            category="partner/text/Gemini",
            essentials_category="Text Generation",
            description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
            "optionally, one or more images, audio clips, videos, or files as multimodal context.",
            inputs=node_inputs,
            outputs=[
                IO.String.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=badge,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
        """Send the prompt plus any connected media/files to the selected model and return its text.

        `model` is the DynamicCombo payload: the chosen display name under "model" plus the
        per-model inputs (media slots, files, thinking level and sampling controls).
        """
        validate_string(prompt, strip_whitespace=True, min_length=1)
        model_id = GEMINI_V2_MODELS[model["model"]]

        def _connected(slot: str) -> list:
            # Autogrow slots arrive as a dict (or may be missing); keep only the filled entries.
            return [item for item in (model.get(slot) or {}).values() if item is not None]

        parts: list[GeminiPart] = [GeminiPart(text=prompt)]
        images = _connected("images")
        audios = _connected("audio")
        videos = _connected("video")
        if images or audios or videos:
            parts.extend(await build_gemini_media_parts(cls, images, audios, videos))

        files = model.get("files")
        if files is not None:
            parts.extend(files)

        system_instruction = (
            GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
            if system_prompt
            else None
        )
        generation_config = GeminiGenerationConfig(
            temperature=model["temperature"],
            topP=model["top_p"],
            maxOutputTokens=model["max_output_tokens"],
            # A seed of 0 is sent as None, matching the "0 = random seed" tooltip.
            seed=seed if seed > 0 else None,
            thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
        )
        request = GeminiGenerateContentRequest(
            contents=[
                GeminiContent(
                    role=GeminiRole.user,
                    parts=parts,
                )
            ],
            generationConfig=generation_config,
            systemInstruction=system_instruction,
        )
        response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
            data=request,
            response_model=GeminiGenerateContentResponse,
            price_extractor=calculate_tokens_price,
        )
        return IO.NodeOutput(get_text_from_response(response) or "Empty response from Gemini model...")
class GeminiInputFiles(IO.ComfyNode):
"""
Loads and formats input files for use with the Gemini API.
@ -609,8 +886,7 @@ class GeminiImage(IO.ComfyNode):
),
IO.Combo.Input(
"model",
options=GeminiImageModel,
default=GeminiImageModel.gemini_2_5_flash_image,
options=["gemini-2.5-flash-image"],
tooltip="The Gemini model to use for generating responses.",
),
IO.Int.Input(
@ -1129,6 +1405,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
optional=True,
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
optional=True,
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
@ -1165,6 +1461,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
seed: int,
response_modalities: str,
system_prompt: str = "",
temperature: float = 1.0,
top_p: float = 0.95,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_choice = model["model"]
@ -1204,6 +1502,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
temperature=temperature,
topP=top_p,
),
systemInstruction=gemini_system_prompt,
),
@ -1222,6 +1522,7 @@ class GeminiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GeminiNode,
GeminiNodeV2,
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,

View File

@ -30,7 +30,7 @@ from comfy_api_nodes.util import (
_GROK_VIDEO_MODEL_API_IDS = {
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
"grok-imagine-video-1.5": "grok-imagine-video-1.5",
}
@ -521,8 +521,8 @@ class GrokVideoNode(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=["480p", "720p"],
tooltip="The resolution of the output video.",
options=["480p", "720p", "1080p"],
tooltip="The resolution of the output video. 1080p is only available for grok-imagine-video-1.5.",
),
IO.Combo.Input(
"aspect_ratio",
@ -570,11 +570,12 @@ class GrokVideoNode(IO.ComfyNode):
(
$is15 := $contains(widgets.model, "1.5");
$rate := $is15
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
? (widgets.resolution = "1080p" ? 0.25 : (widgets.resolution = "720p" ? 0.14 : 0.08))
: (widgets.resolution = "720p" ? 0.07 : 0.05);
$imgCost := $is15 ? 0.0143 : 0.002;
$imgCost := $is15 ? 0.01 : 0.002;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
$total := inputs.image.connected ? $base + $imgCost : $base;
{"type":"usd","usd": $is15 ? $total * 1.43 : $total}
)
""",
),
@ -593,6 +594,8 @@ class GrokVideoNode(IO.ComfyNode):
) -> IO.NodeOutput:
if image is None and model == "grok-imagine-video-1.5":
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
if resolution == "1080p" and model != "grok-imagine-video-1.5":
raise ValueError(f"1080p resolution is only available for grok-imagine-video-1.5, not '{model}'.")
image_url = None
if image is not None:
if get_number_of_images(image) != 1:

View File

@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import (
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
Kling3TurboSettings,
Kling3TurboText2VideoRequest,
Kling3TurboContent,
Kling3TurboImage2VideoRequest,
Kling3TurboCreateResponse,
Kling3TurboQueryResponse,
TaskStatusResponse,
TextToVideoWithAudioRequest,
)
@ -436,7 +442,7 @@ async def execute_text2video(
negative_prompt=negative_prompt if negative_prompt else None,
duration=KlingVideoGenDuration(duration),
mode=KlingVideoGenMode(model_mode),
model_name=KlingVideoGenModelName(model_name),
model_name=model_name,
cfg_scale=cfg_scale,
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
camera_control=camera_control,
@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str:
    """Format storyboard entries as the Turbo multi-shot prompt: 'shot n, duration, text; ...;'.

    Shots are numbered from 1, each duration is converted with int(), and the result always
    ends with a trailing ';'.
    """
    segments = []
    for shot_number, entry in enumerate(multi_prompt, start=1):
        segments.append(f"shot {shot_number}, {int(entry.duration)}, {entry.prompt}")
    return "; ".join(segments) + ";"
def _turbo_video_url(response: Kling3TurboQueryResponse) -> str:
"""Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video')."""
task = response.data[0] if response.data else None
if task and task.outputs:
for output in task.outputs:
if output.type == "video" and output.url:
return output.url
raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}")
async def execute_kling_turbo(
    cls: type[IO.ComfyNode],
    *,
    prompt: str,
    resolution: str,
    aspect_ratio: str,
    duration: int,
    start_frame: torch.Tensor | None,
) -> IO.NodeOutput:
    """Create a Kling 3.0 Turbo task, poll it, and return the downloaded video.

    With a start_frame the image-to-video endpoint is used (aspect_ratio is not sent);
    without one the text-to-video endpoint is used.

    Raises:
        RuntimeError: If the create call returns no task id, or the finished task has no video output.
    """
    if start_frame is None:
        create_path = "/proxy/kling/text-to-video/kling-3.0-turbo"
        create_payload = Kling3TurboText2VideoRequest(
            prompt=prompt,
            settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration),
        )
    else:
        validate_image_dimensions(start_frame, min_width=300, min_height=300)
        validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
        frame_item = Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))
        # The prompt entry is optional here and, when present, goes ahead of the frame entry.
        contents = [Kling3TurboContent(type="prompt", text=prompt), frame_item] if prompt else [frame_item]
        create_path = "/proxy/kling/image-to-video/kling-3.0-turbo"
        create_payload = Kling3TurboImage2VideoRequest(
            contents=contents,
            settings=Kling3TurboSettings(resolution=resolution, duration=duration),  # i2v: no aspect_ratio
        )

    create = await sync_op(
        cls,
        ApiEndpoint(path=create_path, method="POST"),
        response_model=Kling3TurboCreateResponse,
        data=create_payload,
    )
    task_created = bool(create.data and create.data.id)
    if not task_created:
        raise RuntimeError(f"Kling 3.0 Turbo create failed. Code: {create.code}, Message: {create.message}")

    def _first_task_status(polled):
        # Status of the first listed task, or None while the task list is empty.
        return polled.data[0].status if polled.data else None

    final_response = await poll_op(
        cls,
        ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}),
        response_model=Kling3TurboQueryResponse,
        status_extractor=_first_task_status,
    )
    video_url = _turbo_video_url(final_response)
    return IO.NodeOutput(await download_url_to_video_output(video_url))
class KlingVideoNode(IO.ComfyNode):
@classmethod
@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode):
],
tooltip="Generate a series of video segments with individual prompts and durations.",
),
IO.Boolean.Input("generate_audio", default=True),
IO.Boolean.Input(
"generate_audio",
default=True,
tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.",
),
IO.DynamicCombo.Input(
"model",
options=[
@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode):
),
],
),
IO.DynamicCombo.Option(
"kling-3.0-turbo",
[
IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1"],
tooltip="Ignored in image-to-video mode.",
),
],
),
],
tooltip="Model and generation settings.",
),
@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode):
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=[
"model",
"model.resolution",
"generate_audio",
"multi_shot",
@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode):
),
expr="""
(
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
$ms := widgets.multi_shot;
$isSb := $ms != "disabled";
$n := $isSb ? $number($substring($ms, 0, 1)) : 0;
@ -2962,7 +3038,18 @@ class KlingVideoNode(IO.ComfyNode):
$d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
$d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
$dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
{"type":"usd","usd": $rate * $dur}
widgets.model = "kling-3.0-turbo"
? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur}
: (
$rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio);
{"type":"usd","usd": $rate * $dur}
)
)
""",
),
@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode):
duration = multi_shot["duration"]
validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
if model["model"] == "kling-3.0-turbo":
turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"]
return await execute_kling_turbo(
cls,
prompt=turbo_prompt,
resolution=model["resolution"],
aspect_ratio=model["aspect_ratio"],
duration=duration,
start_frame=start_frame,
)
if start_frame is not None:
validate_image_dimensions(start_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))

View File

@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima
_MODEL_MEDIUM = "Krea 2 Medium"
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
_MODEL_LARGE = "Krea 2 Large"
_MODEL_ENDPOINTS: dict[str, str] = {
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
}
@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F
def _krea_model_inputs() -> list:
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
return [
IO.Combo.Input(
"aspect_ratio",
@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode):
"model",
options=[
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
],
tooltip="Krea 2 Medium is best for expressive illustrations; "
@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode):
),
expr="""
(
$isLarge := widgets.model = "krea 2 large";
$rates := {
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
};
$r := $lookup($rates, widgets.model);
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
$hasStyle := $lookup(inputs, "model.style_reference").connected;
$usd := $hasMoodboard
? ($isLarge ? 0.07 : 0.04)
: ($hasStyle
? ($isLarge ? 0.065 : 0.035)
: ($isLarge ? 0.06 : 0.03));
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
{"type":"usd","usd": $usd}
)
""",

View File

@ -3,9 +3,13 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.luma import (
LUMA_KEYFRAME_MODE_FRACTION,
LUMA_KEYFRAME_MODE_SECONDS,
Luma2Generation,
Luma2GenerationRequest,
Luma2ImageRef,
Luma2VideoEdit,
Luma2VideoOptions,
LumaAspectRatio,
LumaCharacterRef,
LumaConceptChain,
@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import (
LumaIO,
LumaKeyframes,
LumaModifyImageRef,
LumaRay32KeyframeChain,
LumaRay32KeyframeItem,
LumaReference,
LumaReferenceChain,
LumaVideoModel,
@ -33,6 +39,7 @@ from comfy_api_nodes.util import (
sync_op,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
)
@ -692,7 +699,10 @@ async def _luma2_upload_image_refs(
async def _luma2_submit_and_poll(
cls: type[IO.ComfyNode],
request: Luma2GenerationRequest,
) -> Input.Image:
*,
estimated_duration: int | None = None,
) -> Luma2Generation:
"""Submit a Luma Agents generation and poll until done; returns the completed generation."""
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
@ -700,21 +710,21 @@ async def _luma2_submit_and_poll(
data=request,
)
if not initial.id:
raise RuntimeError("Luma 2 API did not return a generation id.")
raise RuntimeError("Luma API did not return a generation id.")
final = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
response_model=Luma2Generation,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: None,
estimated_duration=estimated_duration,
)
if not final.output:
if not final.output or not final.output[0].url:
msg = final.failure_reason or "no output returned"
raise RuntimeError(f"Luma 2 generation failed: {msg}")
url = final.output[0].url
if not url:
raise RuntimeError("Luma 2 generation completed without an output URL.")
return await download_url_to_image_tensor(url)
if final.failure_code:
msg = f"{msg} [{final.failure_code}]"
raise RuntimeError(f"Luma generation failed: {msg}")
return final
class LumaImageNode(IO.ComfyNode):
@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode):
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
class LumaImageEditNode(IO.ComfyNode):
@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode):
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
final = await _luma2_submit_and_poll(cls, request)
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
# Price badge keyed by the "resolution" and "duration" widgets: a USD amount looked up per resolution, then per duration.
_BADGE_RAY32_VIDEO = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
expr="""
(
$p := {
"360p": {"5s": 0.06, "10s": 0.18},
"540p": {"5s": 0.15, "10s": 0.45},
"720p": {"5s": 0.3, "10s": 0.9},
"1080p": {"5s": 1.2, "10s": 3.6}
};
{"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)}
)
""",
)
# Price badge keyed by "resolution" only; the amounts equal the "5s" entries of _BADGE_RAY32_VIDEO.
_BADGE_RAY32_VIDEO_5S = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2};
{"type": "usd", "usd": $lookup($p, widgets.resolution)}
)
""",
)
# Price badge showing a min/max USD range per resolution, annotated "(by source length)".
_BADGE_RAY32_EDIT = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {
"360p": {"min": 0.54, "max": 1.08},
"540p": {"min": 0.72, "max": 1.44},
"720p": {"min": 1.08, "max": 2.16},
"1080p": {"min": 2.16, "max": 4.32}
};
$r := $lookup($p, widgets.resolution);
{"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}}
)
""",
)
# Price badge showing a per-resolution USD amount with a "/second" suffix.
_BADGE_RAY32_REFRAME = IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
expr="""
(
$p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36};
{"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}}
)
""",
)
def _ray32_seed_input() -> IO.Input:
    """Build the shared "seed" integer input (0 to 0xFFFFFFFFFFFFFFFF, default 0) for the Ray 3.2 nodes."""
    seed_widget = IO.Int.Input(
        "seed",
        default=0,
        min=0,
        max=0xFFFFFFFFFFFFFFFF,
        control_after_generate=True,
        tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.",
    )
    return seed_widget
async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput:
    """Submit a ray-3.2 request, wait for it, and return a node output of (video, generation_id).

    The generation id falls back to an empty string when the completed generation has none.
    """
    completed = await _luma2_submit_and_poll(cls, request, estimated_duration=120)
    first_output_url = completed.output[0].url
    downloaded_video = await download_url_to_video_output(first_output_url)
    generation_id = completed.id or ""
    return IO.NodeOutput(downloaded_video, generation_id)
class LumaRay32TextToVideoNode(IO.ComfyNode):
    """API node that generates a video from a text prompt with Luma's ray-3.2 model."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare prompt, framing, duration, loop and seed inputs plus video / generation_id outputs."""
        node_inputs = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Combo.Input("duration", options=["5s", "10s"]),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Only available with 5s duration.",
            ),
            _ray32_seed_input(),
        ]
        hidden_fields = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32TextToVideoNode",
            display_name="Luma Ray 3.2 Text to Video",
            category="partner/video/Luma",
            description="Generate a video from a text prompt using Luma's Ray 3.2 model.",
            inputs=node_inputs,
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=hidden_fields,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int
    ) -> IO.NodeOutput:
        """Validate the inputs, build the ray-3.2 video request and run it.

        `seed` is accepted but not forwarded in the request.

        Raises:
            ValueError: If looping is requested together with the 10s duration.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        wants_long_loop = loop and duration == "10s"
        if wants_long_loop:
            raise ValueError("Looping is only available with 5s duration on Ray 3.2.")
        # `loop or None` sends None instead of False when looping is off.
        video_options = Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None)
        request = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video",
            aspect_ratio=aspect_ratio,
            video=video_options,
        )
        return await _ray32_generate(cls, request)
class LumaRay32ImageToVideoNode(IO.ComfyNode):
    """API node that generates a 5s ray-3.2 video anchored on a start and/or end frame."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare prompt, resolution, loop, seed and optional frame inputs plus video / generation_id outputs."""
        node_inputs = [
            IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
            IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
            IO.Boolean.Input(
                "loop",
                default=False,
                tooltip="Make the video loop seamlessly. Not available when an end_frame is set.",
            ),
            _ray32_seed_input(),
            IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."),
            IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."),
        ]
        hidden_fields = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        return IO.Schema(
            node_id="LumaRay32ImageToVideoNode",
            display_name="Luma Ray 3.2 Image to Video",
            category="partner/video/Luma",
            description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. "
            "Image-anchored generations are always 5 seconds.",
            inputs=node_inputs,
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=hidden_fields,
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        loop: bool,
        seed: int,
        start_frame: torch.Tensor | None = None,
        end_frame: torch.Tensor | None = None,
    ) -> IO.NodeOutput:
        """Upload the supplied frame(s), build the ray-3.2 video request and run it.

        `seed` is accepted but not forwarded in the request.

        Raises:
            ValueError: If neither frame is given, or looping is requested together with an end_frame.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
        if start_frame is None and end_frame is None:
            raise ValueError("Provide at least one of start_frame / end_frame.")
        if loop and end_frame is not None:
            raise ValueError("Looping is not available when an end_frame is set.")

        # Duration is fixed to "5s" for this node; `loop or None` sends None when looping is off.
        video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None)
        # Start frame is uploaded before the end frame; absent frames are left unset.
        for field_name, frame in (("start_frame", start_frame), ("end_frame", end_frame)):
            if frame is None:
                continue
            frame_url = await upload_image_to_comfyapi(cls, frame, mime_type="image/png")
            setattr(video, field_name, Luma2ImageRef(url=frame_url))

        request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video)
        return await _ray32_generate(cls, request)
class LumaRay32KeyframeNode(IO.ComfyNode):
    """Chainable node that pins one guide image to a position on the Ray 3.2 output timeline."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: an image, a position selector (fraction or seconds) and an optional chain."""
        image_input = IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video.")

        # Position expressed relative to the clip length.
        fraction_option = IO.DynamicCombo.Option(
            "Fraction of duration (0.0-1.0)",
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Where in the output video this image applies (0.0 = start, 1.0 = end).",
                ),
            ],
        )
        # Position expressed as an absolute timestamp.
        seconds_option = IO.DynamicCombo.Option(
            "Absolute time (seconds)",
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=10.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from the start of the output video where this image applies.",
                ),
            ],
        )
        position_input = IO.DynamicCombo.Input(
            "position",
            options=[fraction_option, seconds_option],
            tooltip="How to place this image on the output video's timeline.",
        )
        upstream_input = IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
            "keyframes",
            optional=True,
            tooltip="Optional earlier keyframes to chain with this one.",
        )

        return IO.Schema(
            node_id="LumaRay32KeyframeNode",
            display_name="Luma Ray 3.2 Keyframe",
            category="partner/video/Luma",
            description=(
                "Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to "
                "the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the "
                "optional 'keyframes' input below."
            ),
            inputs=[image_input, position_input, upstream_input],
            outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: torch.Tensor,
        position: dict,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Append this keyframe to a copy of the incoming chain (or to a fresh chain).

        Args:
            image: Guide image for this keyframe.
            position: DynamicCombo payload; ``position["position"]`` names the selected mode and the
                matching ``"seconds"`` / ``"fraction"`` key carries the value.
            keyframes: Optional upstream chain; it is cloned before the new item is added.

        Returns:
            A node output wrapping the extended chain.
        """
        if keyframes is None:
            timeline = LumaRay32KeyframeChain()
        else:
            timeline = keyframes.clone()

        # Any selection other than the absolute-time label is treated as fraction mode.
        if position["position"] == "Absolute time (seconds)":
            placement_mode = LUMA_KEYFRAME_MODE_SECONDS
            placement_value = float(position["seconds"])
        else:
            placement_mode = LUMA_KEYFRAME_MODE_FRACTION
            placement_value = float(position["fraction"])

        timeline.add(LumaRay32KeyframeItem(image=image, mode=placement_mode, value=placement_value))
        return IO.NodeOutput(timeline)
class LumaRay32KeyframesToVideoNode(IO.ComfyNode):
    """Generate a Ray 3.2 video that passes through a chain of timeline-anchored guide images."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt, resolution, duration, seed and the required keyframe chain."""
        keyframes_input = IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
            "keyframes",
            tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).",
        )
        return IO.Schema(
            node_id="LumaRay32KeyframesToVideoNode",
            display_name="Luma Ray 3.2 Keyframes to Video",
            category="partner/video/Luma",
            description=(
                "Generate a video that interpolates through a sequence of guide images, each anchored to a "
                "position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes "
                "(at least 2)."
            ),
            inputs=[
                IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
                IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
                IO.Combo.Input("duration", options=["5s", "10s"]),
                _ray32_seed_input(),
                keyframes_input,
            ],
            outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO,
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        resolution: str,
        duration: str,
        seed: int,
        keyframes: LumaRay32KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Resolve every keyframe to an output-frame index, upload the images and submit the request.

        Args:
            prompt: Text prompt; must be 1-6000 characters after stripping whitespace.
            resolution: Resolution label forwarded unchanged to the video options.
            duration: ``"5s"`` maps to a last frame index of 120; any other value maps to 240.
            seed: Not read by this method.
            keyframes: Chain of keyframe items; between 2 and 64 items are required.

        Returns:
            Whatever ``_ray32_generate`` produces for the assembled request.

        Raises:
            ValueError: On too few or too many keyframes, a seconds-mode keyframe past the end of
                the clip, or two keyframes that land on the same output frame.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)

        entries = [] if keyframes is None else keyframes.items
        entry_count = len(entries)
        if entry_count < 2:
            raise ValueError(
                "Connect at least 2 Luma Ray 3.2 Keyframe nodes "
                "(use Luma Ray 3.2 Image to Video for a single start/end frame)."
            )
        if entry_count > 64:
            raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {entry_count}.")

        if duration == "5s":
            last_frame = 120
        else:
            last_frame = 240
        clip_seconds = last_frame / 24  # the arithmetic below uses 24 frames per second

        # Map each keyframe to a frame index clamped to [0, last_frame]. All validation happens
        # before any upload so an invalid chain fails without network traffic.
        timeline: list[tuple[int, torch.Tensor]] = []
        for entry in entries:
            if entry.mode == LUMA_KEYFRAME_MODE_SECONDS:
                if entry.value > clip_seconds:
                    raise ValueError(
                        f"Keyframe position {entry.value:g}s is past the end of the {duration} video; "
                        f"use 0-{clip_seconds:g}s (or switch the keyframe to fraction mode)."
                    )
                raw_index = round(entry.value * 24)
            else:
                raw_index = round(entry.value * last_frame)
            timeline.append((min(max(raw_index, 0), last_frame), entry.image))

        # Chaining order is irrelevant: the resolved position decides where each image goes.
        ordered = sorted(timeline, key=lambda pair: pair[0])
        frame_indexes = [frame for frame, _ in ordered]

        # After sorting, duplicates are adjacent; compare each index with its predecessor.
        previous = None
        for frame in frame_indexes:
            if frame == previous:
                raise ValueError(
                    f"Two keyframes resolve to the same output frame ({frame}) for a {duration} video "
                    f"(valid range 0-{last_frame}); give each keyframe a distinct position."
                )
            previous = frame

        # Images are uploaded one after another, in timeline order.
        image_refs: list[Luma2ImageRef] = [
            Luma2ImageRef(url=await upload_image_to_comfyapi(cls, picture, mime_type="image/png"))
            for _, picture in ordered
        ]

        video_options = Luma2VideoOptions(
            resolution=resolution,
            duration=duration,
            keyframes=image_refs,
            keyframe_indexes=frame_indexes,
        )
        payload = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video_options)
        return await _ray32_generate(cls, payload)
class LumaRay32VideoEditNode(IO.ComfyNode):
    """Re-render an uploaded source video under a new prompt with Luma Ray 3.2."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source video, prompt, resolution, edit strength and seed."""
        # "auto" followed by three levels for each strength family, in the order shown to the user.
        strength_choices = ["auto"] + [
            f"{family}_{level}" for family in ("adhere", "flex", "reimagine") for level in (1, 2, 3)
        ]
        return IO.Schema(
            node_id="LumaRay32VideoEditNode",
            display_name="Luma Ray 3.2 Video Edit",
            category="partner/video/Luma",
            description=(
                "Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add "
                "or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited "
                "video keeps the source's length."
            ),
            inputs=[
                IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."),
                IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."),
                IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
                IO.Combo.Input(
                    "strength",
                    options=strength_choices,
                    default="auto",
                    tooltip=(
                        "How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; "
                        "adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most."
                    ),
                ),
                _ray32_seed_input(),
            ],
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=_BADGE_RAY32_EDIT,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and submit a ``video_edit`` request.

        Args:
            video: Source video; uploaded with ``max_duration=18``.
            prompt: Edit description; must be 1-6000 characters after stripping whitespace.
            resolution: Resolution label forwarded unchanged to the video options.
            strength: ``"auto"`` enables ``auto_controls``; any other value is sent as ``strength``.
            seed: Not read by this method.

        Returns:
            Whatever ``_ray32_generate`` produces for the assembled request.
        """
        validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)

        # Best effort: if the duration cannot be read or compared, fall back to the "10s" bucket.
        try:
            is_short_clip = video.get_duration() <= 5.0
        except Exception:
            is_short_clip = False
        duration = "5s" if is_short_clip else "10s"

        uploaded_url = await upload_video_to_comfyapi(cls, video, max_duration=18)

        if strength == "auto":
            edit_settings = Luma2VideoEdit(auto_controls=True)
        else:
            edit_settings = Luma2VideoEdit(strength=strength)

        payload = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_edit",
            source=Luma2ImageRef(url=uploaded_url, media_type="video/mp4"),
            video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit_settings),
        )
        return await _ray32_generate(cls, payload)
class LumaRay32VideoReframeNode(IO.ComfyNode):
    """Change the aspect ratio of an uploaded source video with Luma Ray 3.2."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source video, fill prompt, target aspect ratio, resolution and seed."""
        fill_prompt_input = IO.String.Input(
            "prompt",
            multiline=True,
            default="",
            tooltip="Describes how the newly exposed canvas areas should be filled.",
        )
        return IO.Schema(
            node_id="LumaRay32VideoReframeNode",
            display_name="Luma Ray 3.2 Video Reframe",
            category="partner/video/Luma",
            description=(
                "Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly "
                "exposed canvas areas. Source video up to 30 seconds. Billed per second of output."
            ),
            inputs=[
                IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."),
                fill_prompt_input,
                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
                IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
                _ray32_seed_input(),
            ],
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=_BADGE_RAY32_REFRAME,
        )

    @classmethod
    async def execute(
        cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Upload the source video and submit a ``video_reframe`` request.

        Args:
            video: Source video; uploaded with ``max_duration=30``.
            prompt: Fill description, validated to 1-6000 characters without stripping whitespace.
            aspect_ratio: Target aspect ratio, sent at the top level of the request.
            resolution: Resolution label; ``"1080p"`` is rejected for 9:16 and 3:4.
            seed: Not read by this method.

        Returns:
            Whatever ``_ray32_generate`` produces for the assembled request.

        Raises:
            ValueError: If 1080p is combined with a vertical aspect ratio.
        """
        # NOTE(review): other Ray 3.2 nodes validate with strip_whitespace=True; confirm that accepting
        # a whitespace-only prompt here is intentional.
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)

        vertical_ratios = {"9:16", "3:4"}
        if aspect_ratio in vertical_ratios and resolution == "1080p":
            raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.")

        uploaded_url = await upload_video_to_comfyapi(cls, video, max_duration=30)
        source_ref = Luma2ImageRef(url=uploaded_url, media_type="video/mp4")
        payload = Luma2GenerationRequest(
            prompt=prompt,
            model="ray-3.2",
            type="video_reframe",
            aspect_ratio=aspect_ratio,
            source=source_ref,
            video=Luma2VideoOptions(resolution=resolution),
        )
        return await _ray32_generate(cls, payload)
class LumaRay32ExtendVideoNode(IO.ComfyNode):
    """Extend an earlier Ray 3.2 generation forward or backward by referencing its generation id."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: source generation id, direction (with loop option), prompt, resolution, seed."""
        source_id_input = IO.String.Input(
            "source_generation_id",
            default="",
            tooltip=(
                "generation_id of the prior Ray 3.2 video to extend."
                " Connect the generation_id output of another Luma Ray 3.2 node."
            ),
        )
        # Only the forward option exposes the loop toggle; backward has no extra widgets.
        forward_option = IO.DynamicCombo.Option(
            "Forward (continue after)",
            [
                IO.Boolean.Input(
                    "loop",
                    default=False,
                    tooltip="Loop the extended video seamlessly (forward extend only).",
                ),
            ],
        )
        backward_option = IO.DynamicCombo.Option("Backward (lead-in before)", [])
        direction_input = IO.DynamicCombo.Input(
            "direction",
            options=[forward_option, backward_option],
            tooltip="Forward continues after the prior clip; backward is prepended before it.",
        )
        return IO.Schema(
            node_id="LumaRay32ExtendVideoNode",
            display_name="Luma Ray 3.2 Extend Video",
            category="partner/video/Luma",
            description=(
                "Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in "
                "before it). Connect the generation_id output of a prior Luma Ray 3.2 node."
                " Extensions are always 5 seconds."
            ),
            inputs=[
                source_id_input,
                direction_input,
                IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."),
                IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"),
                _ray32_seed_input(),
            ],
            outputs=[
                IO.Video.Output(),
                IO.String.Output(display_name="generation_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=_BADGE_RAY32_VIDEO_5S,
        )

    @classmethod
    async def execute(
        cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int
    ) -> IO.NodeOutput:
        """Submit a 5-second video request anchored to a prior generation.

        Args:
            source_generation_id: Id of the generation to extend; surrounding whitespace is ignored.
            direction: DynamicCombo payload; ``direction["direction"]`` names the selected option and
                ``"loop"`` may be present for the forward option.
            prompt: Text prompt, validated to 1-6000 characters without stripping whitespace.
            resolution: Resolution label forwarded unchanged to the video options.
            seed: Not read by this method.

        Returns:
            Whatever ``_ray32_generate`` produces for the assembled request.

        Raises:
            ValueError: If ``source_generation_id`` is empty or whitespace-only.
        """
        # NOTE(review): other Ray 3.2 nodes validate with strip_whitespace=True; confirm that accepting
        # a whitespace-only prompt here is intentional.
        validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)

        cleaned_id = source_generation_id.strip() if source_generation_id else ""
        if not cleaned_id:
            raise ValueError(
                "source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)."
            )

        options = Luma2VideoOptions(resolution=resolution, duration="5s")
        anchor = Luma2ImageRef(generation_id=cleaned_id)

        extending_forward = direction["direction"] == "Forward (continue after)"
        if not extending_forward:
            # Backward: the prior generation becomes the end of the new clip.
            options.end_frame = anchor
        else:
            # Forward: the prior generation becomes the start of the new clip.
            options.start_frame = anchor
            if direction.get("loop"):
                options.loop = True

        payload = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=options)
        return await _ray32_generate(cls, payload)
class LumaExtension(ComfyExtension):
@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension):
LumaConceptsNode,
LumaImageNode,
LumaImageEditNode,
LumaRay32TextToVideoNode,
LumaRay32ImageToVideoNode,
LumaRay32KeyframeNode,
LumaRay32KeyframesToVideoNode,
LumaRay32VideoEditNode,
LumaRay32VideoReframeNode,
LumaRay32ExtendVideoNode,
]

View File

@ -9,6 +9,7 @@ from PIL import Image
from typing_extensions import override
import folder_paths
from comfy.utils import common_upscale
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.openai import (
InputFileContent,
@ -62,7 +63,8 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
A torch.Tensor of shape (N, H, W, C) with all returned images; images whose
dimensions differ from the first image's are resized to match it.
Raises:
ValueError: If the response is not valid.
@ -89,6 +91,14 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
# With size="auto" the API can return images whose dimensions differ by a few pixels within a single response;
# resize them to the first image's dimensions so they can be stacked into one batch.
ref_h, ref_w = image_tensors[0].shape[:2]
for i, t in enumerate(image_tensors):
if t.shape[:2] != (ref_h, ref_w):
samples = t.unsqueeze(0).movedim(-1, 1)
samples = common_upscale(samples, ref_w, ref_h, "bilinear", "center")
image_tensors[i] = samples.movedim(1, -1).squeeze(0)
return torch.stack(image_tensors, dim=0)

View File

@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
Model4,
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
RunwayAleph2IO,
RunwayAleph2KeyframeChain,
RunwayAleph2KeyframeItem,
RunwayAleph2PromptImageChain,
RunwayAleph2PromptImageItem,
RunwayAleph2Request,
RunwayAleph2Response,
RunwayAleph2KeyframeSeconds,
RunwayAleph2KeyframeAt,
RunwayAleph2PromptImage,
RunwayAleph2TimestampPosition,
RunwayAleph2RelativePosition,
RunwayAleph2ContentModeration,
KEYFRAME_MODE_SECONDS,
KEYFRAME_MODE_AT,
PROMPT_IMAGE_MODE_TIMESTAMP,
PROMPT_IMAGE_MODE_POSITION,
)
from comfy_api_nodes.util import (
image_tensor_pair_to_batch,
validate_string,
validate_image_dimensions,
validate_image_aspect_ratio,
validate_video_duration,
upload_images_to_comfyapi,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
download_url_to_video_output,
download_url_to_image_tensor,
ApiEndpoint,
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
)
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
AVERAGE_DURATION_T2I_SECONDS = 41
class RunwayApiError(Exception):
"""Base exception for Runway API errors."""
pass
class RunwayGen4TurboAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
return None
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
cls,
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
response_model=TaskStatusResponse,
status_extractor=lambda r: r.status.value,
status_extractor=lambda r: r.status,
estimated_duration=estimated_duration,
progress_extractor=extract_progress_from_task_status,
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
)
@ -127,7 +133,7 @@ async def generate_video(
final_response = await get_response(cls, initial_response.id, estimated_duration)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no video data found in response.")
raise ValueError("Runway task succeeded but no video data found in response.")
video_url = get_video_url_from_task_status(final_response)
return await download_url_to_video_output(video_url)
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
mime_type="image/png",
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
raise ValueError("Failed to upload one or more images to comfy api.")
return IO.NodeOutput(
await generate_video(
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
)
if not final_response.output:
raise RunwayApiError("Runway task succeeded but no image data found in response.")
raise ValueError("Runway task succeeded but no image data found in response.")
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
# Labels of the two timeline-placement modes. They are used both as the DynamicCombo option names of the
# Aleph2 keyframe / prompt-image nodes and as the values those nodes compare against in execute().
_TIMING_ABSOLUTE = "Absolute time (seconds)"
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
class RunwayAleph2KeyframeNode(IO.ComfyNode):
    """Chainable node that pins one guidance image to a moment of the Aleph2 input video."""

    @classmethod
    def define_schema(cls):
        """Describe the node: an image, a timing selector (seconds or fraction) and an optional chain."""
        image_input = IO.Image.Input(
            "image",
            tooltip="The guidance image to apply at the chosen moment of the input video.",
        )
        seconds_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the input video where this image applies.",
                ),
            ],
        )
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip=(
                        "Where in the input video this image applies, "
                        "as a fraction of its duration (0.0 = start, 1.0 = end)."
                    ),
                ),
            ],
        )
        timing_input = IO.DynamicCombo.Input(
            "timing",
            options=[seconds_option, fraction_option],
            tooltip="How to place this image on the input video's timeline.",
        )
        upstream_input = IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
            "keyframes",
            optional=True,
            tooltip="Optional earlier keyframes to chain with this one.",
        )
        return IO.Schema(
            node_id="RunwayAleph2KeyframeNode",
            display_name="Runway Aleph2 Keyframe",
            category="partner/video/Runway",
            description=(
                "Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
                "steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
                "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
                "'keyframes' input below."
            ),
            inputs=[image_input, timing_input, upstream_input],
            outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        timing: dict,
        keyframes: RunwayAleph2KeyframeChain | None = None,
    ) -> IO.NodeOutput:
        """Append this keyframe to a copy of the incoming chain (or to a fresh chain).

        Args:
            image: Guidance image for this keyframe.
            timing: DynamicCombo payload; ``timing["timing"]`` names the selected mode and the
                matching ``"seconds"`` / ``"fraction"`` key carries the value.
            keyframes: Optional upstream chain; it is cloned before the new item is added.

        Returns:
            A node output wrapping the extended chain.
        """
        if keyframes is None:
            timeline = RunwayAleph2KeyframeChain()
        else:
            timeline = keyframes.clone()

        # Any selection other than the absolute-time label is treated as fraction mode.
        if timing["timing"] == _TIMING_ABSOLUTE:
            anchor_mode = KEYFRAME_MODE_SECONDS
            anchor_value = float(timing["seconds"])
        else:
            anchor_mode = KEYFRAME_MODE_AT
            anchor_value = float(timing["fraction"])

        timeline.add(RunwayAleph2KeyframeItem(image=image, mode=anchor_mode, value=anchor_value))
        return IO.NodeOutput(timeline)
class RunwayAleph2PromptImageNode(IO.ComfyNode):
    """Chainable node that pins one guidance image to a moment of the Aleph2 output video."""

    @classmethod
    def define_schema(cls):
        """Describe the node: an image, a position selector (seconds or fraction) and an optional chain."""
        image_input = IO.Image.Input(
            "image",
            tooltip="The guidance image to place at the chosen moment of the output video.",
        )
        seconds_option = IO.DynamicCombo.Option(
            _TIMING_ABSOLUTE,
            [
                IO.Float.Input(
                    "seconds",
                    default=0.0,
                    min=0.0,
                    max=30.0,
                    step=0.1,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Time in seconds from start of the output video where this image applies.",
                ),
            ],
        )
        fraction_option = IO.DynamicCombo.Option(
            _TIMING_FRACTION,
            [
                IO.Float.Input(
                    "fraction",
                    default=0.0,
                    min=0.0,
                    max=1.0,
                    step=0.01,
                    display_mode=IO.NumberDisplay.number,
                    tooltip=(
                        "Where in the output video this image applies, "
                        "as a fraction of its duration (0.0 = start, 1.0 = end)."
                    ),
                ),
            ],
        )
        position_input = IO.DynamicCombo.Input(
            "position",
            options=[seconds_option, fraction_option],
            tooltip="How to place this image on the output video's timeline.",
        )
        upstream_input = IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
            "prompt_images",
            optional=True,
            tooltip="Optional earlier prompt images to chain with this one.",
        )
        return IO.Schema(
            node_id="RunwayAleph2PromptImageNode",
            display_name="Runway Aleph2 Prompt Image",
            category="partner/video/Runway",
            description=(
                "Anchor a guidance image to a moment of the output (result) video, to guide what "
                "the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
                "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
                "'prompt_images' input below."
            ),
            inputs=[image_input, position_input, upstream_input],
            outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
        )

    @classmethod
    def execute(
        cls,
        image: Input.Image,
        position: dict,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Append this prompt image to a copy of the incoming chain (or to a fresh chain).

        Args:
            image: Guidance image for this entry.
            position: DynamicCombo payload; ``position["position"]`` names the selected mode and the
                matching ``"seconds"`` / ``"fraction"`` key carries the value.
            prompt_images: Optional upstream chain; it is cloned before the new item is added.

        Returns:
            A node output wrapping the extended chain.
        """
        if prompt_images is None:
            timeline = RunwayAleph2PromptImageChain()
        else:
            timeline = prompt_images.clone()

        # Any selection other than the absolute-time label is treated as fraction mode.
        if position["position"] == _TIMING_ABSOLUTE:
            anchor_mode = PROMPT_IMAGE_MODE_TIMESTAMP
            anchor_value = float(position["seconds"])
        else:
            anchor_mode = PROMPT_IMAGE_MODE_POSITION
            anchor_value = float(position["fraction"])

        timeline.add(RunwayAleph2PromptImageItem(image=image, mode=anchor_mode, value=anchor_value))
        return IO.NodeOutput(timeline)
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
    """Edit a video with Runway's Aleph2 model, optionally guided by keyframes or prompt images."""

    @classmethod
    def define_schema(cls):
        """Describe the node: prompt, source video, seed, moderation threshold and optional guidance chains."""
        return IO.Schema(
            node_id="RunwayAleph2VideoToVideoNode",
            display_name="Runway Aleph2 Video to Video",
            category="partner/video/Runway",
            description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
            "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
            "the original motion and timing; the output resolution matches the input video, which must be "
            "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
            "the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Describes what should appear in the output (1-1000 characters).",
                ),
                IO.Video.Input(
                    "video",
                    tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=4294967295,
                    step=1,
                    control_after_generate=True,
                    display_mode=IO.NumberDisplay.number,
                    tooltip="Random seed for generation",
                ),
                IO.Combo.Input(
                    "public_figure_threshold",
                    options=["auto", "low"],
                    default="low",
                    tooltip="Content moderation for recognizable public figures.",
                ),
                IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
                    "keyframes",
                    optional=True,
                    tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
                IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
                    "prompt_images",
                    optional=True,
                    tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
                    "Use keyframes or prompt images, not both.",
                ),
            ],
            outputs=[
                IO.Video.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        video: Input.Video,
        seed: int,
        public_figure_threshold: str = "low",
        keyframes: RunwayAleph2KeyframeChain | None = None,
        prompt_images: RunwayAleph2PromptImageChain | None = None,
    ) -> IO.NodeOutput:
        """Validate the inputs, upload the media, run the Aleph2 task and return the edited video.

        All local validation (prompt, duration, frame rate, guidance counts and timestamps) is
        completed before the first upload, so an invalid graph fails without spending time and
        bandwidth on uploading the source video or guidance images.

        Args:
            prompt: Text prompt, 1-1000 characters.
            video: Source video; must be 2-30 seconds long and at most 30 fps.
            seed: Seed forwarded to the request.
            public_figure_threshold: Content-moderation threshold forwarded to the request.
            keyframes: Optional chain of guidance images anchored to the input video (max 5).
            prompt_images: Optional chain of guidance images anchored to the output video (max 5).

        Returns:
            A node output wrapping the downloaded result video.

        Raises:
            ValueError: On an invalid frame rate, both guidance kinds supplied, more than 5 guidance
                items, an absolute timestamp beyond the source duration, or a finished task that
                carries no output.
        """
        validate_string(prompt, min_length=1, max_length=1000)
        validate_video_duration(
            video,
            min_duration=2.0,
            max_duration=30.0,
        )

        # Best effort: an unreadable frame rate is not treated as an error.
        try:
            fps = float(video.get_frame_rate())
        except Exception:
            fps = None
        # Small tolerance so nominal 30 fps material with rounding noise is still accepted.
        if fps is not None and fps > 30.0 + 0.01:
            raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")

        if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
            raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")

        # Best effort: without a readable duration the timestamp checks below are skipped.
        video_duration: float | None = None
        try:
            video_duration = video.get_duration()
        except Exception:
            video_duration = None

        def _check_seconds(value: float, label: str) -> None:
            """Reject an absolute timestamp that lies beyond the end of the source video."""
            if video_duration is not None and value > video_duration + 0.0001:
                raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")

        keyframe_items = keyframes.items if keyframes is not None else []
        prompt_image_items = prompt_images.items if prompt_images is not None else []

        # Validate every guidance item up front, before anything is uploaded.
        if len(keyframe_items) > 5:
            raise ValueError("Aleph2 supports at most 5 keyframes.")
        for item in keyframe_items:
            if item.mode == KEYFRAME_MODE_SECONDS:
                _check_seconds(item.value, "Keyframe timestamp")
        if len(prompt_image_items) > 5:
            raise ValueError("Aleph2 supports at most 5 prompt images.")
        for item in prompt_image_items:
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                _check_seconds(item.value, "Prompt image timestamp")

        video_url = await upload_video_to_comfyapi(cls, video)

        keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
        for item in keyframe_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            if item.mode == KEYFRAME_MODE_SECONDS:
                keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
            else:
                keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))

        prompt_image_models: list[RunwayAleph2PromptImage] = []
        for item in prompt_image_items:
            image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
            position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
            if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
                position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
            else:
                position = RunwayAleph2RelativePosition(positionPercentage=item.value)
            prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))

        initial_response = await sync_op(
            cls,
            endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
            response_model=RunwayAleph2Response,
            data=RunwayAleph2Request(
                promptText=prompt,
                videoUri=video_url,
                seed=seed,
                contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
                # Empty guidance lists are sent as None rather than as empty arrays.
                keyframes=keyframe_models or None,
                promptImage=prompt_image_models or None,
            ),
        )
        final_response = await get_response(cls, initial_response.id)
        if not final_response.output:
            raise ValueError("Runway task succeeded but no video data found in response.")
        return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
class RunwayExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
RunwayImageToVideoNodeGen3a,
RunwayImageToVideoNodeGen4,
RunwayTextToImageNode,
RunwayAleph2VideoToVideoNode,
RunwayAleph2KeyframeNode,
RunwayAleph2PromptImageNode,
]

View File

@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
)
from comfy_api_nodes.util._helpers import (
default_base_url,
get_auth_header,
get_comfy_api_headers,
get_node_id,
is_processing_interrupted,
)
@ -100,8 +100,7 @@ class SoniloTextToMusic(IO.ComfyNode):
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="partner/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
description="Generate music from a text prompt using Sonilo's AI model.",
inputs=[
IO.String.Input(
"prompt",
@ -111,11 +110,10 @@ class SoniloTextToMusic(IO.ComfyNode):
),
IO.Int.Input(
"duration",
default=0,
min=0,
default=30,
min=1,
max=360,
tooltip="Target duration in seconds. Set to 0 to let the model "
"infer the duration from the prompt. Maximum: 6 minutes.",
tooltip="Target duration in seconds. Maximum: 6 minutes.",
),
IO.Int.Input(
"seed",
@ -136,13 +134,7 @@ class SoniloTextToMusic(IO.ComfyNode):
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
expr="""
(
widgets.duration > 0
? {"type":"usd","usd": 0.005 * widgets.duration}
: {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
)
""",
expr='{"type":"usd","usd": 0.0025 * widgets.duration}',
),
)
@ -150,14 +142,13 @@ class SoniloTextToMusic(IO.ComfyNode):
async def execute(
cls,
prompt: str,
duration: int = 0,
duration: int = 1,
seed: int = 0,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=1000)
form = aiohttp.FormData()
form.add_field("prompt", prompt)
if duration > 0:
form.add_field("duration", str(duration))
form.add_field("duration", str(duration))
audio_bytes = await _stream_sonilo_music(
cls,
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
@ -174,8 +165,7 @@ async def _stream_sonilo_music(
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
headers: dict[str, str] = {}
headers.update(get_auth_header(cls))
headers = get_comfy_api_headers(cls)
headers.update(endpoint.headers)
node_id = get_node_id(cls)

View File

@ -1,6 +1,6 @@
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.tripo import (
TripoAnimateRetargetRequest,
TripoAnimateRigRequest,
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.tripo import (
TripoFileEmptyReference,
TripoFileReference,
TripoImageToModelRequest,
TripoImportModelRequest,
TripoModelVersion,
TripoMultiviewToModelRequest,
TripoOrientation,
@ -21,6 +22,7 @@ from comfy_api_nodes.apis.tripo import (
TripoTaskType,
TripoTextToModelRequest,
TripoTextureModelRequest,
TripoTexturePrompt,
TripoUrlReference,
)
from comfy_api_nodes.util import (
@ -28,6 +30,7 @@ from comfy_api_nodes.util import (
download_url_to_file_3d,
poll_op,
sync_op,
upload_3d_model_to_comfyapi,
upload_images_to_comfyapi,
)
@ -538,6 +541,14 @@ class TripoTextureNode(IO.ComfyNode):
optional=True,
advanced=True,
),
IO.String.Input(
"texture_prompt",
default="",
multiline=True,
optional=True,
tooltip="Optional text guidance for texturing. Required in practice for imported "
"models (Tripo: Import Model), which carry no source image to infer colors from.",
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -571,6 +582,7 @@ class TripoTextureNode(IO.ComfyNode):
texture_seed: int | None = None,
texture_quality: str | None = None,
texture_alignment: str | None = None,
texture_prompt: str = "",
) -> IO.NodeOutput:
response = await sync_op(
cls,
@ -583,6 +595,7 @@ class TripoTextureNode(IO.ComfyNode):
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
),
)
return await poll_until_finished(cls, response, average_duration=80)
@ -915,6 +928,90 @@ class TripoConversionNode(IO.ComfyNode):
return await poll_until_finished(cls, response, average_duration=30)
class TripoImportModelNode(IO.ComfyNode):
    """Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""

    SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")

    @classmethod
    def define_schema(cls):
        """Describe the node: one multi-type 3D model input and a MODEL_TASK_ID output."""
        model_input = IO.MultiType.Input(
            "model_3d",
            types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
            tooltip=(
                "3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
                "OBJ and STL files carry no embedded textures."
            ),
        )
        return IO.Schema(
            node_id="TripoImportModelNode",
            display_name="Tripo: Import Model",
            category="partner/3d/Tripo",
            description=(
                "Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
                "to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
                "GLB is recommended: textures survive import only when embedded in the file. "
                "Note that texturing an imported model requires a texture prompt."
            ),
            inputs=[model_input],
            outputs=[
                IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"text","text":"Free"}""",
            ),
        )

    @classmethod
    async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
        """Validate, upload and import a 3D model, then wait for the import task to finish.

        Args:
            model_3d: Model file; its ``format`` (leading dot and case ignored) must be one of
                ``SUPPORTED_FORMATS`` and its byte size must not exceed 150 MB.

        Returns:
            A node output carrying the id of the import task.

        Raises:
            ValueError: For ``gltf``, any other unsupported format, or an oversized file.
            RuntimeError: If the import request reports a non-zero code or polling does not end
                in the SUCCESS status.
        """
        extension = (model_3d.format or "").lstrip(".").lower()
        if extension == "gltf":
            raise ValueError(
                "GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
            )
        if extension not in cls.SUPPORTED_FORMATS:
            accepted = ", ".join(map(str.upper, cls.SUPPORTED_FORMATS))
            raise ValueError(
                f"Unsupported 3D format '{extension or 'unknown'}'. "
                f"Tripo import supports: {accepted}."
            )

        # Size is checked locally so an oversized file is rejected before it is uploaded.
        mebibyte = 1024 * 1024
        byte_count = len(model_3d.get_bytes())
        if byte_count > 150 * mebibyte:
            raise ValueError(f"Model file is {byte_count / mebibyte:.1f} MB; Tripo import allows up to 150 MB.")

        model_url = await upload_3d_model_to_comfyapi(cls, model_3d, extension)
        submitted = await sync_op(
            cls,
            endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
            response_model=TripoTaskResponse,
            data=TripoImportModelRequest(url=model_url, format=extension),
        )
        if submitted.code != 0:
            raise RuntimeError(f"Failed to import model: {submitted.error}")

        task_id = submitted.data.task_id
        terminal_failures = [
            TripoTaskStatus.FAILED,
            TripoTaskStatus.CANCELLED,
            TripoTaskStatus.UNKNOWN,
            TripoTaskStatus.BANNED,
            TripoTaskStatus.EXPIRED,
        ]
        finished = await poll_op(
            cls,
            poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
            response_model=TripoTaskResponse,
            failed_statuses=terminal_failures,
            status_extractor=lambda x: x.data.status,
            progress_extractor=lambda x: x.data.progress,
            estimated_duration=10,
        )
        # Only an explicit SUCCESS status is accepted as a completed import.
        if finished.data.status != TripoTaskStatus.SUCCESS:
            raise RuntimeError(f"Failed to import model: {finished}")
        return IO.NodeOutput(task_id)
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
return (
"("
@ -1292,6 +1389,7 @@ class TripoExtension(ComfyExtension):
TripoP1TextToModelNode,
TripoP1ImageToModelNode,
TripoP1MultiviewToModelNode,
TripoImportModelNode,
TripoTextureNode,
TripoRefineNode,
TripoRigNode,

View File

@ -48,10 +48,13 @@ from comfy_api_nodes.util import (
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_audio_duration,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
validate_video_duration,
)
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-t2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=[
"16:9",
"9:16",
"1:1",
"4:3",
"3:4",
"21:9",
"9:21",
"5:4",
"4:5",
],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-t2v",
[
@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-i2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-i2v",
[
@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
seed: int,
watermark: bool,
):
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False)
media = [
Wan27MediaItem(
type="first_frame",
@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.1-r2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the video. Use identifiers such as 'character1' and "
"'character2' to refer to the reference characters.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=[
"16:9",
"9:16",
"1:1",
"4:3",
"3:4",
"21:9",
"9:21",
"5:4",
"4:5",
],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_image"),
names=[
"image1",
"image2",
"image3",
"image4",
"image5",
"image6",
"image7",
"image8",
"image9",
],
min=1,
),
),
],
),
IO.DynamicCombo.Option(
"happyhorse-1.0-r2v",
[
@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$ppsTable := $contains(widgets.model, "1.1")
? { "720p": 0.2002, "1080p": 0.2574 }
: { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
watermark: bool,
):
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
media = []
reference_images = model.get("reference_images", {})
for key in reference_images:
validate_image_dimensions(reference_images[key], min_width=400, min_height=400)
validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False)
media = []
for key in reference_images:
media.append(
Wan27MediaItem(
@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
)
)
if not media:
raise ValueError("At least one reference reference image must be provided.")
raise ValueError("At least one reference image must be provided.")
initial_response = await sync_op(
cls,

View File

@ -4,11 +4,14 @@ import os
import re
import time
from collections.abc import Callable
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from io import BytesIO
from yarl import URL
from comfy.cli_args import args
from comfy.deploy_environment import get_deploy_environment
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO
@ -35,6 +38,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
return {}
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
    """Return the usage source recorded for the prompt that triggered this API node.

    Reads ``node_cls.hidden.comfy_usage_source``. When that value is empty or
    missing (the submitting client did not identify itself, i.e. a direct API
    call to this server), ``"comfyui-api"`` is returned instead.
    """
    source = node_cls.hidden.comfy_usage_source
    if source:
        return source
    return "comfyui-api"
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
    """Build the header set shared by Comfy API requests.

    The result combines the auth header(s) from ``get_auth_header`` with
    ``Comfy-Env`` (deploy environment) and ``Comfy-Usage-Source`` (usage source).
    Keeping this in one place means new shared headers are added once.

    Because the result carries auth, it is meant for relative/cloud URLs resolved
    against ``default_base_url()``; callers must not attach it to arbitrary
    absolute/presigned URLs.
    """
    headers = dict(get_auth_header(node_cls))
    headers["Comfy-Env"] = get_deploy_environment()
    headers["Comfy-Usage-Source"] = get_usage_source(node_cls)
    return headers
def default_base_url() -> str:
return getattr(args, "comfy_api_base", "https://api.comfy.org")
@ -66,6 +93,32 @@ async def sleep_with_interrupt(
await asyncio.sleep(min(1.0, end - now))
def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float:
"""Delay before the next retry, honoring a server ``Retry-After`` header."""
seconds: float | None = None
if value is not None:
value = value.strip()
if value.isascii() and value.isdigit():
# delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float()
# an all-digit string always converts (huge values become inf, never raising).
seconds = float(value)
elif value:
# HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets
try:
parsed = parsedate_to_datetime(value)
except (TypeError, ValueError, OverflowError):
parsed = None
if parsed is not None:
if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC
parsed = parsed.replace(tzinfo=timezone.utc)
delta = (parsed - datetime.now(timezone.utc)).total_seconds()
seconds = delta if delta > 0 else 0.0
if seconds is None:
return fallback
return min(seconds, max_wait)
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()

View File

@ -19,12 +19,11 @@ from comfy import utils
from comfy_api.latest import IO
from server import PromptServer
from comfy.deploy_environment import get_deploy_environment
from . import request_logger
from ._helpers import (
_retry_after_wait,
default_base_url,
get_auth_header,
get_comfy_api_headers,
get_node_id,
is_processing_interrupted,
sleep_with_interrupt,
@ -84,6 +83,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
@ -645,8 +645,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
payload_headers["Comfy-Env"] = get_deploy_environment()
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers)
@ -750,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
should_retry = True
if should_retry:
wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT)
logging.warning(
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,

View File

@ -17,7 +17,7 @@ from folder_paths import get_output_directory
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
get_comfy_api_headers,
is_processing_interrupted,
sleep_with_interrupt,
to_aiohttp_url,
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
headers = get_comfy_api_headers(cls)
while True:
attempt += 1

View File

@ -0,0 +1,66 @@
"""Enrich executed-node output entries with asset id."""
import logging
import os
def enrich_output_with_assets(output_ui: dict) -> dict:
    """Register file-type output entries as assets and add their ``id``.

    When ``args.enable_assets`` (--enable-assets) is off, ``output_ui`` is
    returned as-is. Otherwise a new dict is built: values that are not lists are
    carried over untouched, and inside lists every dict entry that has both
    ``"filename"`` and ``"type"`` and resolves to an existing file under the
    directory for its type is replaced by a copy carrying an ``"id"`` key.
    Failures are handled per entry (logged as warnings) so one bad entry never
    blocks execution or the remaining entries.
    """
    from comfy.cli_args import args

    if not args.enable_assets:
        return output_ui

    import folder_paths
    from app.assets.services.ingest import register_file_in_place, DependencyMissingError

    def _with_asset_id(entry):
        """Return ``entry`` with an asset id added, or the entry as-is when it cannot be enriched."""
        try:
            root = folder_paths.get_directory_by_type(entry["type"])
            if root is None:
                return entry
            root_abs = os.path.abspath(root)
            file_abs = os.path.abspath(os.path.join(root_abs, entry.get("subfolder") or "", entry["filename"]))
            # commonpath raises ValueError for paths it cannot compare; that is
            # treated the same as a path that leaves the base directory.
            try:
                outside_base = os.path.commonpath([root_abs, file_abs]) != root_abs
            except ValueError:
                outside_base = True
            if outside_base:
                logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
                return entry
            if not os.path.isfile(file_abs):
                return entry
            # Register unconditionally: the file was just produced, and
            # register_file_in_place re-hashes so an overwritten path can
            # never carry a stale id.
            registered = register_file_in_place(
                abs_path=file_abs,
                name=entry["filename"],
                tags=[entry["type"]],
            )
            entry = dict(entry)  # copy so the caller's entry is not mutated
            entry["id"] = registered.ref.id
        except DependencyMissingError:
            logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
        except Exception:
            logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
        return entry

    enriched = {}
    for key, entries in output_ui.items():
        if not isinstance(entries, list):
            enriched[key] = entries
            continue
        rebuilt = []
        for entry in entries:
            is_file_entry = isinstance(entry, dict) and "filename" in entry and "type" in entry
            rebuilt.append(_with_asset_id(entry) if is_file_entry else entry)
        enriched[key] = rebuilt
    return enriched

View File

@ -3,11 +3,23 @@ Job utilities for the /api/jobs endpoint.
Provides normalization and helper functions for job status tracking.
"""
from typing import Optional
import uuid
from typing import Callable, Optional
from comfy_api.internal import prune_dict
# Result of classifying a job for cancellation.
# 'running' -> job is currently executing (interrupt it)
# 'pending' -> job is queued but not started (dequeue it)
# 'terminal' -> job already finished (present in history); cancel is a no-op
# 'unknown' -> job id is not present anywhere
CANCEL_RUNNING = 'running'
CANCEL_PENDING = 'pending'
CANCEL_TERMINAL = 'terminal'
CANCEL_UNKNOWN = 'unknown'
class JobStatus:
"""Job status constants."""
PENDING = 'pending'
@ -19,6 +31,25 @@ class JobStatus:
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
def validate_job_id(value) -> str:
    """Validate a client-supplied job (prompt) id and return it unchanged.

    Job ids must be UUIDs in the canonical lowercase hyphenated form. The id is
    stored and compared verbatim downstream (history keys, websocket events,
    /interrupt matching), so accepting another spelling would silently rewrite
    the client's id and then miss every exact-match lookup. Rejecting loudly
    beats that.

    Raises:
        ValueError: ``value`` is not a string, is not parseable as a UUID, or is
            a UUID written in a non-canonical form.
    """
    if isinstance(value, str):
        # uuid.UUID accepts several spellings (uppercase, braces, no hyphens);
        # round-tripping through str() leaves only the canonical one equal.
        canonical = str(uuid.UUID(value))
        if canonical == value:
            return value
        raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
    raise ValueError(f"job id must be a string, got {type(value).__name__}")
# Media types that can be previewed in the frontend
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
@ -387,3 +418,71 @@ def get_all_jobs(
jobs = jobs[:limit]
return (jobs, total_count)
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
    """Classify a job id for cancellation.

    Args:
        prompt_id: Id of the job to look up.
        running: Queue items currently executing; each item's element at index 1
            is compared against ``prompt_id``.
        queued: Queue items not yet started, same layout as ``running``.
        history: Mapping keyed by prompt id; a job present there has already
            finished, so cancelling it is a no-op.

    Returns:
        CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL or CANCEL_UNKNOWN,
        checked in that order.
    """
    def _contains(items: list) -> bool:
        return any(item[1] == prompt_id for item in items)

    if _contains(running):
        return CANCEL_RUNNING
    if _contains(queued):
        return CANCEL_PENDING
    return CANCEL_TERMINAL if prompt_id in history else CANCEL_UNKNOWN
def cancel_job(
    prompt_id: str,
    running: list,
    queued: list,
    history: dict,
    interrupt: Callable[[str], bool],
    dequeue: Callable[[str], bool],
) -> str:
    """Cancel a single job by id, whatever state it is in.

    The job is first classified with ``classify_job_for_cancel`` and the cancel
    is then mapped onto the two supplied callbacks:
      - a running job is interrupted via ``interrupt``
      - a pending job is removed from the queue via ``dequeue``
      - a finished (terminal) job or an unknown id is a no-op; callers that
        need fail-fast behaviour should validate ids up front with
        ``classify_job_for_cancel``

    Both callbacks take the prompt id and return whether they acted on a job
    that was *actually* in that state, so the returned value reflects what
    truly happened rather than the (possibly stale) classification. This
    matters in the narrow TOCTOU windows between the caller's snapshot and the
    action:
      - classified RUNNING but finished before ``interrupt`` fires:
        ``interrupt`` returns False and CANCEL_UNKNOWN is returned.
      - classified PENDING but started before ``dequeue`` fires: ``dequeue``
        returns False, ``interrupt`` catches the now-running job and
        CANCEL_RUNNING is returned. If it had already finished, both return
        False and CANCEL_UNKNOWN is returned.

    ``interrupt`` must be atomic (interrupt the job only if it is still the one
    running) so a cancel can never land on an unrelated prompt that started in
    the meantime (see ``execution.PromptQueue.interrupt_if_running``).

    Returns:
        CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL or CANCEL_UNKNOWN.
    """
    state = classify_job_for_cancel(prompt_id, running, queued, history)

    if state == CANCEL_PENDING and dequeue(prompt_id):
        return CANCEL_PENDING

    if state in (CANCEL_RUNNING, CANCEL_PENDING):
        # Either the job was classified as running, or it left the pending
        # queue between classification and dequeue. In the latter case it is
        # interrupted if it started executing; otherwise it already finished
        # and the cancel is a genuine no-op.
        if interrupt(prompt_id):
            return CANCEL_RUNNING
        return CANCEL_UNKNOWN

    # CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
    return state

View File

@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio",
category="model/conditioning",
category="model/conditioning/ace",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="model/conditioning",
category="model/conditioning/ace",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="model/latent/audio",
category="model/latent/ace",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="model/latent/audio",
category="model/latent/ace",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
IO.Int.Input(
@ -111,8 +111,8 @@ class ReferenceAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ReferenceTimbreAudio",
display_name="Reference Audio",
category="advanced/conditioning/audio",
display_name="Set Reference Audio",
category="model/conditioning",
is_experimental=True,
description="This node sets the reference audio for ace step 1.5",
inputs=[

View File

@ -16,7 +16,7 @@ class APG(io.ComfyNode):
return io.Schema(
node_id="APG",
display_name="Adaptive Projected Guidance",
category="model/sampling/custom_sampling",
category="model/sampling/custom",
inputs=[
io.Model.Input("model"),
io.Float.Input(

View File

@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="EmptyARVideoLatent",
category="model/latent/video",
category="model/latent/autoregressive",
inputs=[
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ARVideoI2V",
category="model/conditioning/video_models",
category="model/conditioning/autoregressive",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),

View File

@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="model/latent/audio",
category="model/latent",
essentials_category="Audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ConditioningStableAudio",
category="model/conditioning",
category="model/conditioning/stable audio",
inputs=[
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode):
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="model/latent/audio",
category="model/latent",
inputs=[
IO.Audio.Input("audio"),
IO.Vae.Input("vae"),
@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode):
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="model/latent/audio",
category="model/latent",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode):
node_id="VAEDecodeAudioTiled",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio (Tiled)",
category="model/latent/audio",
category="model/latent",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
display_name="Save Audio (FLAC) (DEPRECATED)",
category="audio",
essentials_category="Audio",
inputs=[
@ -166,7 +166,9 @@ class SaveAudio(IO.ComfyNode):
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -174,11 +176,10 @@ class SaveAudio(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
)
save_flac = execute # TODO: remove
class SaveAudioMP3(IO.ComfyNode):
@classmethod
@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
display_name="Save Audio (MP3) (DEPRECATED)",
category="audio",
essentials_category="Audio",
inputs=[
@ -195,7 +196,9 @@ class SaveAudioMP3(IO.ComfyNode):
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -203,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_mp3 = execute # TODO: remove
class SaveAudioOpus(IO.ComfyNode):
@classmethod
@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
display_name="Save Audio (Opus) (DEPRECATED)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
@ -225,7 +227,9 @@ class SaveAudioOpus(IO.ComfyNode):
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_deprecated=True,
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
@ -233,12 +237,57 @@ class SaveAudioOpus(IO.ComfyNode):
if audio is None:
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
return IO.NodeOutput(
audio,
ui=UI.AudioSaveHelper.get_save_audio_ui(
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
)
)
save_opus = execute # TODO: remove
class SaveAudioAdvanced(IO.ComfyNode):
    """Output node that saves audio as FLAC, MP3 or Opus and passes the audio through.

    The file format is chosen with a dynamic combo; MP3 and Opus expose an
    additional ``quality`` choice.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: audio + filename prefix + format selector, audio output."""
        return IO.Schema(
            node_id="SaveAudioAdvanced",
            search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"],
            display_name="Save Audio (Advanced)",
            description="Saves the input audio to your ComfyUI output directory.",
            category="audio",
            inputs=[
                IO.Audio.Input("audio", tooltip="The audio to save."),
                IO.String.Input(
                    "filename_prefix",
                    default="audio/ComfyUI",
                    tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."),
                ),
                IO.DynamicCombo.Input(
                    "format",
                    options=[
                        IO.DynamicCombo.Option("flac", []),
                        IO.DynamicCombo.Option("mp3", [
                            IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
                        ]),
                        IO.DynamicCombo.Option("opus", [
                            IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
                        ]),
                    ],
                    tooltip="The file format in which to save the audio.",
                ),
            ],
            hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
            is_output_node=True,
            outputs=[IO.Audio.Output("audio")],
        )

    @classmethod
    def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput:
        """Save ``audio`` in the selected format and return it together with the save UI.

        Args:
            audio: The audio to save.
            filename_prefix: Prefix for the saved file.
            format: Dynamic-combo value; ``"format"`` holds the selected file
                format and ``"quality"`` is read when present.

        Raises:
            ValueError: ``audio`` is None.
        """
        # Same guard as the other save/preview audio nodes: fail with a clear
        # message instead of an obscure error inside the save helper.
        if audio is None:
            raise ValueError("SaveAudioAdvanced: input audio is None (source video may have no audio track).")

        file_format = format.get("format", None)
        quality = format.get("quality", None)

        save_kwargs = {"filename_prefix": filename_prefix, "cls": cls, "format": file_format}
        if quality:
            # Only forwarded when a quality was selected, so the helper's own
            # default applies otherwise.
            save_kwargs["quality"] = quality
        ui = UI.AudioSaveHelper.get_save_audio_ui(audio, **save_kwargs)
        return IO.NodeOutput(audio, ui=ui)
class PreviewAudio(IO.ComfyNode):
@ -254,13 +303,14 @@ class PreviewAudio(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
outputs=[IO.Audio.Output("audio")]
)
@classmethod
def execute(cls, audio) -> IO.NodeOutput:
if audio is None:
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls))
save_flac = execute # TODO: remove
@ -822,6 +872,7 @@ class AudioExtension(ComfyExtension):
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,
SaveAudioAdvanced,
LoadAudio,
PreviewAudio,
ConditioningStableAudio,

View File

@ -0,0 +1,108 @@
import torch
from typing_extensions import override
import comfy.model_management
import comfy.utils
import node_helpers
from comfy_api.latest import ComfyExtension, io
def _resize_long_edge(image, max_size, stride=16):
    """Downscale so the long edge fits ``max_size``, then snap both sides to ``stride``.

    Height and width are read from ``image.shape[1]`` and ``image.shape[2]``.
    The scale factor is capped at 1.0, so images are never enlarged by the
    scaling step; each side is then rounded to the nearest multiple of
    ``stride`` with a minimum of one stride. Only the first three entries of
    the last axis are kept.
    """
    height, width = image.shape[1], image.shape[2]
    long_edge = max(height, width)
    scale = min(max_size / long_edge, 1.0)

    def _snap(side):
        return max(stride, round(side * scale / stride) * stride)

    new_height = _snap(height)
    new_width = _snap(width)
    # common_upscale is called on the layout with the last axis moved to position 1.
    channels_first = image[:, :, :, :3].movedim(-1, 1)
    resized = comfy.utils.common_upscale(channels_first, new_width, new_height, "area", "disabled")
    return resized.movedim(1, -1)
class BerniniConditioning(io.ComfyNode):
    """Bernini in-context conditioning for a Wan2.2-A14B model.

    Attaches the VAE-encoded source video / reference video / reference images
    to both conditionings under the ``context_latents`` key and returns an
    empty latent of the requested size.

    The task is inferred from which inputs are connected:
      (nothing)                   -> t2v   (text-to-video)
      source_video                -> v2v   (video-to-video)
      source_video + ref_images   -> rv2v  (reference-guided video editing)
      ref_images only             -> r2v   (reference-to-video)
      source_video + ref_video    -> ads2v (insert image/video into video)

    source_video is the edit base / canvas (resized to width x height).
    reference_video is moving content to composite in.
    Streams are ordered source_video, reference_video, then reference_images
    -> source_id (1, 2, 3, ...).
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs (conditioning, VAE, sizes, optional media) and outputs."""
        return io.Schema(
            node_id="BerniniConditioning",
            display_name="Bernini Conditioning",
            category="model/conditioning/bernini",
            # The second literal starts with a space so the two sentences do not run together.
            description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
                        " Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
            inputs=[
                io.Conditioning.Input("positive"),
                io.Conditioning.Input("negative"),
                io.Vae.Input("vae"),
                io.Int.Input("width", default=832, min=16, max=8192, step=16),
                io.Int.Input("height", default=480, min=16, max=8192, step=16),
                io.Int.Input("length", default=81, min=1, max=8192, step=4),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("source_video", optional=True, tooltip=("Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
                io.Image.Input("reference_video", optional=True, tooltip=("Video to insert into the source video (ads2v).")),
                io.Autogrow.Input("reference_images", optional=True,
                                  template=io.Autogrow.TemplatePrefix(
                                      input=io.Image.Input("reference_image", tooltip=("Reference image injected as an in-context token (r2v, rv2v).")),
                                      prefix="reference_image_", min=0, max=8)),
                io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
                    "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
            ],
            outputs=[
                io.Conditioning.Output(display_name="positive"),
                io.Conditioning.Output(display_name="negative"),
                io.Latent.Output(display_name="latent"),
            ],
        )

    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
        """Encode the connected context inputs and attach them to both conditionings.

        Returns the (possibly updated) positive and negative conditioning plus a
        zero latent sized from batch_size / length / height / width. When no
        context input is connected the conditionings are returned unchanged.
        """
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())

        # Stream order: source_video (1), reference_video (2), reference_images (3, 4, ...).
        context = []
        if source_video is not None:
            # The canvas is trimmed to `length` frames and resized to width x height.
            vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
            context.append(vae.encode(vid[:, :, :, :3]))
        if reference_video is not None:
            ref_vid = _resize_long_edge(reference_video[:length], ref_max_size)  # moving content, native aspect
            context.append(vae.encode(ref_vid[:, :, :, :3]))

        # reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
        # separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
        if reference_images:
            for name in sorted(reference_images):
                imgs = reference_images[name]
                if imgs is None:
                    continue
                for i in range(imgs.shape[0]):
                    img = _resize_long_edge(imgs[i:i + 1], ref_max_size)  # native aspect per ref
                    context.append(vae.encode(img[:, :, :, :3]))

        if context:
            positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
            negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
        return io.NodeOutput(positive, negative, {"samples": latent})
class BerniniExtension(ComfyExtension):
    """Extension that registers the Bernini conditioning node."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes = [BerniniConditioning]
        return nodes
async def comfy_entrypoint() -> BerniniExtension:
    """Create and return the Bernini extension instance."""
    extension = BerniniExtension()
    return extension

View File

@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode):
category="image/background removal",
description="Generates a foreground mask to remove the background from an image using a background removal model.",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"),
IO.Image.Input("image", tooltip="Input image to remove the background from")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
def execute(cls, bg_removal_model, image):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)

Some files were not shown because too many files have changed in this diff Show More