mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Merge branch 'Comfy-Org:master' into qwen-image-vae
This commit is contained in:
commit
ca686606ae
@ -1,5 +1,4 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
@ -7,9 +6,9 @@ If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
run_amd_gpu_enable_dynamic_vram.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
|
||||
2
.github/workflows/check-line-endings.yml
vendored
2
.github/workflows/check-line-endings.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
22
README.md
22
README.md
@ -140,7 +140,7 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
|
||||
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||
- Serves as the foundation for the desktop release
|
||||
|
||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||
2. **[Comfy Desktop](https://github.com/Comfy-Org/Comfy-Desktop)**
|
||||
- Builds a new release using the latest stable core version
|
||||
|
||||
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
|
||||
@ -309,7 +309,7 @@ After this you should have everything installed and can proceed to running Comfy
|
||||
|
||||
#### Apple Mac silicon
|
||||
|
||||
You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
|
||||
You can install ComfyUI in Apple Mac silicon (M1, M2, M3 or M4) with any recent macOS version.
|
||||
|
||||
1. Install pytorch nightly. For instructions, read the [Accelerated PyTorch training on Mac](https://developer.apple.com/metal/pytorch/) Apple Developer guide (make sure to install the latest pytorch nightly).
|
||||
1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux.
|
||||
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
@ -382,11 +382,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
@ -462,16 +458,6 @@ To use the most up-to-date frontend version:
|
||||
|
||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Drop the vestigial tags.tag_type column.
|
||||
|
||||
tag_type was always "user" in practice — no code path ever set it to anything
|
||||
else (no system/seeded classification was ever wired up) and nothing queried it.
|
||||
The column, its index (ix_tags_tag_type), and the corresponding API field were
|
||||
dead weight, so they are removed.
|
||||
|
||||
Revision ID: 0004_drop_tag_type
|
||||
Revises: 0003_add_metadata_job_id
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0004_drop_tag_type"
|
||||
down_revision = "0003_add_metadata_job_id"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.drop_index("ix_tags_tag_type")
|
||||
batch_op.drop_column("tag_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"tag_type",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
)
|
||||
)
|
||||
batch_op.create_index("ix_tags_tag_type", ["tag_type"])
|
||||
@ -39,6 +39,7 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
try:
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after=q.after,
|
||||
)
|
||||
except InvalidCursorError as e:
|
||||
return _build_error_response(400, "INVALID_CURSOR", str(e))
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
# has_more semantics differ by mode:
|
||||
# - cursor mode: a non-empty next_cursor means there are more results.
|
||||
# - offset mode: derived from total - (offset + page size).
|
||||
if q.after is not None:
|
||||
has_more = result.next_cursor is not None
|
||||
else:
|
||||
has_more = (q.offset + len(summaries)) < result.total
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
has_more=has_more,
|
||||
next_cursor=result.next_cursor,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
@ -519,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
@_require_assets_feature_enabled
|
||||
async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content_param = request.query.get("delete_content")
|
||||
delete_content = (
|
||||
False
|
||||
if delete_content_param is None
|
||||
else delete_content_param.lower() not in {"0", "false", "no"}
|
||||
)
|
||||
|
||||
try:
|
||||
# Deleting an asset is a soft delete of the reference; the underlying
|
||||
# content is preserved (it may be shared with other references).
|
||||
deleted = delete_asset_reference(
|
||||
reference_id=reference_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
delete_content_if_orphan=False,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
@ -575,8 +585,8 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
tags = [
|
||||
schemas_out.TagUsage(name=name, count=count, type=tag_type)
|
||||
for (name, tag_type, count) in rows
|
||||
schemas_out.TagUsage(name=name, count=count)
|
||||
for (name, count) in rows
|
||||
]
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
|
||||
@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel):
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
|
||||
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
|
||||
# Supplying `after` together with `sort=last_access_time` returns
|
||||
# 400 INVALID_CURSOR; that sort only supports offset/limit.
|
||||
after: str | None = None
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
|
||||
@ -41,12 +41,13 @@ class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
# Opaque cursor for the next page. Omitted when there are no more results.
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
|
||||
@ -227,7 +227,6 @@ class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="tag",
|
||||
@ -240,7 +239,5 @@ class Tag(Base):
|
||||
overlaps="asset_reference_links,tag_links,tags,asset_reference",
|
||||
)
|
||||
|
||||
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
@ -266,9 +266,18 @@ def list_references_page(
|
||||
metadata_filter: dict | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
after_cursor_id: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
"""List references with pagination, filtering, and sorting.
|
||||
|
||||
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
|
||||
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
|
||||
strictly after the given ``(sort_col, id)`` position in the active sort
|
||||
direction. The cursor value must already be typed for the column
|
||||
(datetime for time sorts, int for size, str for name); the caller decodes
|
||||
the opaque cursor string and resolves to the typed value.
|
||||
|
||||
Returns (references, tag_map, total_count).
|
||||
"""
|
||||
base = (
|
||||
@ -297,9 +306,31 @@ def list_references_page(
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetReference.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
descending = order == "desc"
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
|
||||
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
|
||||
if after_cursor_value is not None and after_cursor_id is not None:
|
||||
if descending:
|
||||
keyset = sa.or_(
|
||||
sort_col < after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
|
||||
)
|
||||
else:
|
||||
keyset = sa.or_(
|
||||
sort_col > after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
|
||||
)
|
||||
base = base.where(keyset)
|
||||
|
||||
# Secondary ORDER BY id (matching the primary direction) gives the keyset
|
||||
# comparison a deterministic tiebreaker on duplicate sort_col values.
|
||||
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
|
||||
sort_exp = sort_col.desc() if descending else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp, id_exp).limit(limit)
|
||||
if after_cursor_id is None:
|
||||
base = base.offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
|
||||
@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str]) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
rows = [{"name": n} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
@ -97,7 +95,7 @@ def set_reference_tags(
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
ensure_tags_exist(session, to_add)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
@ -142,7 +140,7 @@ def add_tags_to_reference(
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
ensure_tags_exist(session, norm)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -289,7 +287,6 @@ def list_tags_with_usage(
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
@ -331,7 +328,7 @@ def list_tags_with_usage(
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
rows_norm = [(name, int(count or 0)) for (name, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@ -354,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
ensure_tags_exist(sess, tag_pool)
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
|
||||
@ -506,6 +507,10 @@ def enrich_asset(
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if dims:
|
||||
system_metadata.update(dims)
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
if full_hash:
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import timezone
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
@ -149,6 +160,16 @@ def delete_asset_reference(
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
"""Delete an asset reference.
|
||||
|
||||
With ``delete_content_if_orphan=False`` (a soft delete), the reference is
|
||||
hidden and the underlying content is preserved. With ``True``, the content
|
||||
is also removed once it becomes orphaned.
|
||||
|
||||
Note: the public DELETE /api/assets/{id} endpoint always soft-deletes
|
||||
(passes ``False``); the orphan-reclamation path is intentionally
|
||||
internal-only, retained for a future GC/admin caller.
|
||||
"""
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
@ -242,6 +263,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
# Sort fields that support cursor pagination. `last_access_time` is not
|
||||
# in this list — it falls back to offset/limit.
|
||||
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -252,7 +278,39 @@ def list_assets_page(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
after: str | None = None,
|
||||
) -> ListAssetsResult:
|
||||
"""List assets with optional cursor pagination.
|
||||
|
||||
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
|
||||
must match ``sort`` and be in the cursor-supported allowlist; mismatches
|
||||
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
|
||||
"""
|
||||
cursor_value: object | None = None
|
||||
cursor_id: str | None = None
|
||||
# Mint next_cursor on every page where the sort is cursor-supported, not
|
||||
# only when the request itself arrived with a cursor. Otherwise a first
|
||||
# request (no `after`) returns next_cursor=None and the client can never
|
||||
# enter cursor mode.
|
||||
mint_cursor = sort in _CURSOR_SORT_FIELDS
|
||||
|
||||
if after is not None:
|
||||
if sort not in _CURSOR_SORT_FIELDS:
|
||||
raise InvalidCursorError(
|
||||
f"cursor pagination is not supported for sort={sort!r}"
|
||||
)
|
||||
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
|
||||
if payload.sort_field != sort:
|
||||
raise InvalidCursorError(
|
||||
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
|
||||
)
|
||||
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
|
||||
|
||||
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
|
||||
# remaining" from "more rows past this page" without a second query. Drop
|
||||
# the sentinel before returning.
|
||||
fetch_limit = limit + 1 if mint_cursor else limit
|
||||
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
@ -261,12 +319,22 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
limit=fetch_limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after_cursor_value=cursor_value,
|
||||
after_cursor_id=cursor_id,
|
||||
)
|
||||
|
||||
next_cursor: str | None = None
|
||||
if mint_cursor and len(refs) > limit:
|
||||
# There's at least one more row past this page — mint a cursor from
|
||||
# the last row of the page (i.e. index `limit - 1`, since we
|
||||
# over-fetched), and drop the sentinel.
|
||||
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
|
||||
refs = refs[:limit]
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
@ -277,7 +345,39 @@ def list_assets_page(
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
|
||||
|
||||
|
||||
def _resolve_cursor_value(payload: CursorPayload) -> object:
|
||||
"""Map a decoded cursor payload to a column-typed Python value."""
|
||||
if payload.sort_field in ("created_at", "updated_at"):
|
||||
# DB stores naive UTC; strip tzinfo so the comparison binds against a
|
||||
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
|
||||
return decode_cursor_time(payload).replace(tzinfo=None)
|
||||
if payload.sort_field == "size":
|
||||
return decode_cursor_int(payload)
|
||||
return payload.value # name, str-typed
|
||||
|
||||
|
||||
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
|
||||
"""Mint a cursor pointing at *ref* for the given sort dimension.
|
||||
|
||||
Returns None when the boundary row carries a NULL sort value (e.g. an asset
|
||||
record whose size_bytes hasn't been backfilled). Continuing pagination
|
||||
across a NULL boundary is undefined under keyset ordering — better to
|
||||
truncate cleanly here than to mint a cursor that mis-positions.
|
||||
"""
|
||||
if sort == "name":
|
||||
return encode_cursor("name", ref.name, ref.id, order=order)
|
||||
if sort == "size":
|
||||
if ref.asset is None or ref.asset.size_bytes is None:
|
||||
return None
|
||||
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
|
||||
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
|
||||
value = ref.created_at if sort == "created_at" else ref.updated_at
|
||||
if value is None:
|
||||
return None
|
||||
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
|
||||
213
app/assets/services/cursor.py
Normal file
213
app/assets/services/cursor.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""Opaque keyset-pagination cursor for /api/assets.
|
||||
|
||||
Payload JSON uses short keys to keep the encoded length small:
|
||||
|
||||
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
|
||||
|
||||
The `o` key binds the cursor to the sort direction it was minted under,
|
||||
so replaying a `desc` cursor against an `asc` request fails with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
`o` is mandatory on every payload — a cursor without it is rejected as
|
||||
malformed.
|
||||
|
||||
Encoding is base64url with no padding. Cursors are opaque tokens: the
|
||||
payload format is internal to this server, and clients must treat a
|
||||
cursor as a black box handed back via `next_cursor`. No byte-level
|
||||
compatibility with any other implementation is required.
|
||||
|
||||
Time values are serialized as Unix microseconds (UTC) — microsecond
|
||||
precision is sufficient to round-trip the timestamps stored by the
|
||||
database without rounding rows in the same millisecond bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
class InvalidCursorError(ValueError):
|
||||
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
|
||||
|
||||
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
|
||||
"""
|
||||
|
||||
|
||||
# Wire-format length caps. Cursors are user-controlled, so caps protect the
|
||||
# decode path from oversized allocations and downstream SQL predicates from
|
||||
# unbounded strings.
|
||||
#
|
||||
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
|
||||
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
|
||||
# server then refuses on the next request.
|
||||
#
|
||||
# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above
|
||||
# the largest cursor the per-field caps can produce. Worst case is value + id
|
||||
# at their caps with every character JSON-escaping to the six-byte `\uXXXX`
|
||||
# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192
|
||||
# the encoder can never mint a cursor that exceeds it, so a freshly minted
|
||||
# cursor always decodes on the next request and there is no user-visible
|
||||
# "cursor too long" failure.
|
||||
MAX_ENCODED_CURSOR_LENGTH = 8192
|
||||
MAX_CURSOR_VALUE_LENGTH = 512
|
||||
MAX_CURSOR_ID_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CursorPayload:
|
||||
sort_field: str
|
||||
value: str
|
||||
id: str
|
||||
order: str
|
||||
|
||||
|
||||
_VALID_ORDERS = ("asc", "desc")
|
||||
|
||||
|
||||
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
|
||||
"""Encode a cursor payload as a base64url (no-padding) string.
|
||||
|
||||
`order` binds the cursor to the sort direction it was minted under so a
|
||||
later request with a flipped `order` query parameter is rejected with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
"""
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
|
||||
# Symmetric input validation: the encoder must reject anything the
|
||||
# decoder rejects, or the same server will mint cursors it then 400s on
|
||||
# the next request.
|
||||
if not id:
|
||||
raise InvalidCursorError("id must be non-empty")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
payload = {"s": sort_field, "v": value, "id": id, "o": order}
|
||||
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
# No mint-time length guard is needed: the per-field caps above bound the
|
||||
# encoded length well below MAX_ENCODED_CURSOR_LENGTH (see its definition),
|
||||
# so the encoder can never produce a cursor the decode path would reject.
|
||||
return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
|
||||
"""Encode a time-typed cursor at Unix microsecond precision.
|
||||
|
||||
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
|
||||
datetimes are rejected so callers can't accidentally encode the local
|
||||
wall-clock value of a UTC-stored timestamp.
|
||||
"""
|
||||
if t.tzinfo is None:
|
||||
raise ValueError("encode_cursor_from_time requires an aware datetime")
|
||||
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
|
||||
return encode_cursor(sort_field, str(micros), id, order=order)
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
cursor: str,
|
||||
allowed_sort_fields: Iterable[str],
|
||||
expected_order: str | None = None,
|
||||
) -> CursorPayload:
|
||||
"""Parse an opaque cursor.
|
||||
|
||||
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
|
||||
cursor carrying a field outside this set is rejected so a cursor minted
|
||||
for one column can't be replayed against another (e.g. a ``created_at``
|
||||
timestamp string compared against a ``name`` column).
|
||||
|
||||
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
|
||||
payload's ``o`` field. ``o`` is required on every payload; a cursor
|
||||
missing it is rejected as malformed.
|
||||
|
||||
Passing no allowed fields rejects every cursor.
|
||||
"""
|
||||
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("cursor exceeds maximum length")
|
||||
|
||||
try:
|
||||
# urlsafe_b64decode requires correct padding; we strip on encode, so
|
||||
# restore the trailing '=' pad here.
|
||||
padding = "=" * (-len(cursor) % 4)
|
||||
raw = base64.urlsafe_b64decode(cursor + padding)
|
||||
except (ValueError, base64.binascii.Error) as e:
|
||||
raise InvalidCursorError(f"encoding: {e}") from e
|
||||
|
||||
try:
|
||||
decoded = json.loads(raw)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise InvalidCursorError(f"payload: {e}") from e
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
raise InvalidCursorError("payload: expected object")
|
||||
|
||||
sort_field = decoded.get("s")
|
||||
value = decoded.get("v")
|
||||
id = decoded.get("id")
|
||||
order = decoded.get("o")
|
||||
|
||||
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
|
||||
raise InvalidCursorError("payload: missing or non-string s/v/id")
|
||||
|
||||
if id == "":
|
||||
raise InvalidCursorError("missing id")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
|
||||
if sort_field not in allowed_sort_fields:
|
||||
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
|
||||
|
||||
if not isinstance(order, str):
|
||||
raise InvalidCursorError("missing or non-string o")
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"unsupported order {order!r}")
|
||||
if expected_order is not None and order != expected_order:
|
||||
raise InvalidCursorError(
|
||||
f"cursor order {order!r} does not match request order {expected_order!r}"
|
||||
)
|
||||
|
||||
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
|
||||
|
||||
|
||||
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
|
||||
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
micros = int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
|
||||
try:
|
||||
return _unix_micros_to_datetime(micros)
|
||||
except (OverflowError, OSError, ValueError) as e:
|
||||
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
|
||||
# in fromtimestamp / datetime construction. Map to 400, not 500.
|
||||
raise InvalidCursorError(f"value is out of representable range: {e}") from e
|
||||
|
||||
|
||||
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
|
||||
"""Parse a cursor value as a base-10 integer."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
return int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
|
||||
|
||||
|
||||
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _datetime_to_unix_micros(t: datetime) -> int:
|
||||
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
|
||||
delta = t - _EPOCH
|
||||
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
|
||||
|
||||
|
||||
def _unix_micros_to_datetime(micros: int) -> datetime:
|
||||
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
|
||||
seconds, micro_remainder = divmod(micros, 1_000_000)
|
||||
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)
|
||||
63
app/assets/services/image_dimensions.py
Normal file
63
app/assets/services/image_dimensions.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Image dimension extraction for asset ingest.
|
||||
|
||||
Reads only the image header via Pillow to capture width/height cheaply,
|
||||
without a full pixel decode. Returns a metadata dict suitable for merging
|
||||
into ``AssetReference.system_metadata``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_image_dimensions(
|
||||
file_path: str, mime_type: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract image dimensions for the file at ``file_path``.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to a file on disk.
|
||||
mime_type: Optional MIME type hint. When provided and not prefixed
|
||||
with ``image/``, extraction is skipped without touching the file.
|
||||
|
||||
Returns:
|
||||
``{"kind": "image", "width": W, "height": H}`` when the file is a
|
||||
recognizable image with positive dimensions, otherwise ``None``.
|
||||
|
||||
The dict shape is intended to be merged into ``system_metadata`` so the
|
||||
asset response surfaces ``metadata.kind`` plus dimension fields for image
|
||||
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
|
||||
duration/fps) can extend this shape without schema changes.
|
||||
"""
|
||||
if mime_type is not None and not mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
try:
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"Pillow not available; skipping image dimension extraction for %s",
|
||||
file_path,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
except (OSError, UnidentifiedImageError, ValueError) as exc:
|
||||
logger.debug(
|
||||
"Failed to read image dimensions from %s: %s", file_path, exc
|
||||
)
|
||||
return None
|
||||
|
||||
if (
|
||||
not isinstance(width, int)
|
||||
or not isinstance(height, int)
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
return None
|
||||
|
||||
return {"kind": "image", "width": width, "height": height}
|
||||
@ -17,9 +17,11 @@ from app.assets.database.queries import (
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
list_references_by_asset_id,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_system_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
@ -29,6 +31,7 @@ from app.assets.database.queries import (
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
@ -118,6 +121,14 @@ def _ingest_file_from_path(
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
_maybe_store_image_dimensions(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=locator,
|
||||
mime_type=mime_type,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
@ -288,6 +299,13 @@ def _register_existing_asset(
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
_backfill_image_dimensions_from_siblings(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
new_reference_id=ref.id,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
|
||||
)
|
||||
|
||||
|
||||
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
|
||||
|
||||
|
||||
def _maybe_store_image_dimensions(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str,
|
||||
mime_type: str | None,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
|
||||
|
||||
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
|
||||
safetensors metadata, download provenance) are preserved by merge.
|
||||
"""
|
||||
if not mime_type or not mime_type.startswith("image/"):
|
||||
return
|
||||
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if not dims:
|
||||
return
|
||||
|
||||
current = current_system_metadata or {}
|
||||
merged = dict(current)
|
||||
merged.update(dims)
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
|
||||
|
||||
def _backfill_image_dimensions_from_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
new_reference_id: str,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Copy image dimension keys from any sibling reference of the same asset.
|
||||
|
||||
The from-hash path doesn't read the file bytes, so dimensions can't be
|
||||
extracted there directly. When another reference of the same asset already
|
||||
carries image dimensions, copy them onto the new reference so consumers
|
||||
see consistent metadata regardless of how the asset was registered.
|
||||
|
||||
Best-effort: missing siblings, non-image siblings, or absent dimension
|
||||
keys leave the target reference unchanged.
|
||||
"""
|
||||
current = current_system_metadata or {}
|
||||
if current.get("kind") == "image" and "width" in current and "height" in current:
|
||||
return
|
||||
|
||||
for sibling in list_references_by_asset_id(session, asset_id):
|
||||
if sibling.id == new_reference_id:
|
||||
continue
|
||||
meta = sibling.system_metadata or {}
|
||||
if meta.get("kind") != "image":
|
||||
continue
|
||||
width = meta.get("width")
|
||||
height = meta.get("height")
|
||||
if (
|
||||
type(width) is not int
|
||||
or type(height) is not int
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
continue
|
||||
merged = dict(current)
|
||||
merged["kind"] = "image"
|
||||
merged["width"] = width
|
||||
merged["height"] = height
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=new_reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
@ -56,7 +56,6 @@ class IngestResult:
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@ -71,6 +70,7 @@ class AssetSummaryData:
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@ -75,7 +75,7 @@ def list_tags(
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
return [TagUsage(name, count) for name, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
|
||||
4191
blueprints/Character Replacement (SCAIL-2 Base).json
Normal file
4191
blueprints/Character Replacement (SCAIL-2 Base).json
Normal file
File diff suppressed because it is too large
Load Diff
4461
blueprints/Character Replacement (SCAIL-2 Extend).json
Normal file
4461
blueprints/Character Replacement (SCAIL-2 Extend).json
Normal file
File diff suppressed because it is too large
Load Diff
569
blueprints/Image Depth Estimation (Depth Anything 3).json
Normal file
569
blueprints/Image Depth Estimation (Depth Anything 3).json
Normal file
@ -0,0 +1,569 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 89,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 89,
|
||||
"type": "85e595bd-af9e-40ee-85c5-b98bb15da47a",
|
||||
"pos": [
|
||||
320,
|
||||
520
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
360
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resolution"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "resize_method"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "output_type",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "output"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "output_normalization",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "output.normalization"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "apply_sky_clip",
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "output.apply_sky_clip"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"87",
|
||||
"resolution"
|
||||
],
|
||||
[
|
||||
"87",
|
||||
"resize_method"
|
||||
],
|
||||
[
|
||||
"86",
|
||||
"output"
|
||||
],
|
||||
[
|
||||
"86",
|
||||
"output.normalization"
|
||||
],
|
||||
[
|
||||
"86",
|
||||
"output.apply_sky_clip"
|
||||
],
|
||||
[
|
||||
"88",
|
||||
"model_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0"
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Image Depth Estimation (Depth Anything 3)"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "85e595bd-af9e-40ee-85c5-b98bb15da47a",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 4,
|
||||
"lastNodeId": 89,
|
||||
"lastLinkId": 109,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 2,
|
||||
"config": {},
|
||||
"name": "Image Depth Estimation (Depth Anything 3)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
400,
|
||||
90,
|
||||
166.998046875,
|
||||
188
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
1250,
|
||||
146,
|
||||
128,
|
||||
68
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "43cf3118-495a-487d-8eb3-a17c7e92f64f",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
19
|
||||
],
|
||||
"localized_name": "image",
|
||||
"pos": [
|
||||
542.998046875,
|
||||
114
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1089a0a1-6db1-45a8-84b0-0bfdc2ed920a",
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
22
|
||||
],
|
||||
"pos": [
|
||||
542.998046875,
|
||||
134
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "25fb64ac-26d5-466d-995b-6d51b9afa2c4",
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
23
|
||||
],
|
||||
"pos": [
|
||||
542.998046875,
|
||||
154
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "8acafb7c-6c8b-46b3-9d74-c563498a3af1",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"linkIds": [
|
||||
24
|
||||
],
|
||||
"label": "output_type",
|
||||
"pos": [
|
||||
542.998046875,
|
||||
174
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "1da5009b-4648-43e8-a257-16426630cf22",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
25
|
||||
],
|
||||
"label": "output_normalization",
|
||||
"pos": [
|
||||
542.998046875,
|
||||
194
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "fd7edb33-5fb1-4538-a411-26e5039a9321",
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
26
|
||||
],
|
||||
"label": "apply_sky_clip",
|
||||
"pos": [
|
||||
542.998046875,
|
||||
214
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "b5be4c8a-b833-4f1e-8c94-3ed1dd722190",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
106
|
||||
],
|
||||
"pos": [
|
||||
542.998046875,
|
||||
234
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "478ab537-63bc-4d74-a9f0-c975f550880f",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
7
|
||||
],
|
||||
"localized_name": "IMAGE",
|
||||
"pos": [
|
||||
1274,
|
||||
170
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 86,
|
||||
"type": "DA3Render",
|
||||
"pos": [
|
||||
800,
|
||||
310
|
||||
],
|
||||
"size": [
|
||||
380,
|
||||
130
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "da3_geometry",
|
||||
"name": "da3_geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"link": 12
|
||||
},
|
||||
{
|
||||
"localized_name": "output",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "output"
|
||||
},
|
||||
"link": 24
|
||||
},
|
||||
{
|
||||
"localized_name": "output.normalization",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "output.normalization"
|
||||
},
|
||||
"link": 25
|
||||
},
|
||||
{
|
||||
"localized_name": "output.apply_sky_clip",
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "output.apply_sky_clip"
|
||||
},
|
||||
"link": 26
|
||||
},
|
||||
{
|
||||
"name": "geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
7
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DA3Render",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0"
|
||||
},
|
||||
"widgets_values": [
|
||||
"depth",
|
||||
"v2_style",
|
||||
false
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 87,
|
||||
"type": "DA3Inference",
|
||||
"pos": [
|
||||
800,
|
||||
50
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
130
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "da3_model",
|
||||
"name": "da3_model",
|
||||
"type": "DA3_MODEL",
|
||||
"link": 107
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 19
|
||||
},
|
||||
{
|
||||
"localized_name": "resolution",
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resolution"
|
||||
},
|
||||
"link": 22
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_method",
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "resize_method"
|
||||
},
|
||||
"link": 23
|
||||
},
|
||||
{
|
||||
"localized_name": "mode",
|
||||
"name": "mode",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "mode"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "da3_geometry",
|
||||
"name": "da3_geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
12
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DA3Inference",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0"
|
||||
},
|
||||
"widgets_values": [
|
||||
504,
|
||||
"upper_bound_resize",
|
||||
"mono"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 88,
|
||||
"type": "LoadDA3Model",
|
||||
"pos": [
|
||||
810,
|
||||
-160
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
140
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model_name",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": 106
|
||||
},
|
||||
{
|
||||
"localized_name": "weight_dtype",
|
||||
"name": "weight_dtype",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "weight_dtype"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "DA3_MODEL",
|
||||
"name": "DA3_MODEL",
|
||||
"type": "DA3_MODEL",
|
||||
"links": [
|
||||
107
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LoadDA3Model",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "depth_anything_3_mono_large.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/Depth-Anything-3/resolve/main/geometry_estimation/depth_anything_3_mono_large.safetensors",
|
||||
"directory": "geometry_estimation"
|
||||
}
|
||||
]
|
||||
},
|
||||
"widgets_values": [
|
||||
"depth_anything_3_mono_large.safetensors",
|
||||
"default"
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 12,
|
||||
"origin_id": 87,
|
||||
"origin_slot": 0,
|
||||
"target_id": 86,
|
||||
"target_slot": 0,
|
||||
"type": "DA3_GEOMETRY"
|
||||
},
|
||||
{
|
||||
"id": 19,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 87,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"origin_id": 86,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 22,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 87,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 23,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 87,
|
||||
"target_slot": 3,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 24,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 86,
|
||||
"target_slot": 1,
|
||||
"type": "COMFY_DYNAMICCOMBO_V3"
|
||||
},
|
||||
{
|
||||
"id": 25,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 86,
|
||||
"target_slot": 2,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 26,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 86,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 106,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 88,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 107,
|
||||
"origin_id": 88,
|
||||
"origin_slot": 0,
|
||||
"target_id": 87,
|
||||
"target_slot": 0,
|
||||
"type": "DA3_MODEL"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Depth",
|
||||
"description": "This subgraph takes an input image and produces a depth map using the Depth Anything 3 model, which recovers spatially consistent geometry from any number of views. It is ideal for single or multi-view images, videos, and 3D scenes where accurate depth estimation is needed for tasks like SLAM, novel view synthesis, or spatial perception. The model uses a plain transformer backbone and supports both monocular and multi-view inputs without."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"BlueprintDescription": "This subgraph takes an input image and produces a depth map using the Depth Anything 3 model, which recovers spatially consistent geometry from any number of views. It is ideal for single or multi-view images, videos, and 3D scenes where accurate depth estimation is needed for tasks like SLAM, novel view synthesis, or spatial perception. The model uses a plain transformer backbone and supports both monocular and multi-view inputs without."
|
||||
}
|
||||
}
|
||||
3549
blueprints/Image Edit (Bernini-R).json
Normal file
3549
blueprints/Image Edit (Bernini-R).json
Normal file
File diff suppressed because it is too large
Load Diff
1983
blueprints/Image to Gaussian Splat (TripoSplat).json
Normal file
1983
blueprints/Image to Gaussian Splat (TripoSplat).json
Normal file
File diff suppressed because it is too large
Load Diff
1088
blueprints/Text to Image (Anima Base 1.0).json
Normal file
1088
blueprints/Text to Image (Anima Base 1.0).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -1077,9 +1077,12 @@
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Image generation and editing/Text to image"
|
||||
"category": "Image generation and editing/Text to image",
|
||||
"description": "This subgraph converts text prompts into non-photorealistic illustrations using a 2-billion-parameter model optimized for anime and artistic styles. It is ideal for generating concept art, character designs, or stylized illustrations where photorealism is not required. The model excels with anime and artistic content but performs poorly on realistic subjects."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {}
|
||||
"extra": {
|
||||
"BlueprintDescription": "This subgraph converts text prompts into non-photorealistic illustrations using a 2-billion-parameter model optimized for anime and artistic styles. It is ideal for generating concept art, character designs, or stylized illustrations where photorealism is not required. The model excels with anime and artistic content but performs poorly on realistic subjects."
|
||||
}
|
||||
}
|
||||
2473
blueprints/Text to Image (Ideogram v4).json
Normal file
2473
blueprints/Text to Image (Ideogram v4).json
Normal file
File diff suppressed because it is too large
Load Diff
825
blueprints/Video Depth Estimation (Depth Anything 3).json
Normal file
825
blueprints/Video Depth Estimation (Depth Anything 3).json
Normal file
@ -0,0 +1,825 @@
|
||||
{
|
||||
"revision": 0,
|
||||
"last_node_id": 97,
|
||||
"last_link_id": 0,
|
||||
"nodes": [
|
||||
{
|
||||
"id": 97,
|
||||
"type": "253ec5ca-8333-4ddf-a036-9fc0923651b9",
|
||||
"pos": [
|
||||
410,
|
||||
500
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
400
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "start_time",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "start_time"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "duration",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "duration"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resolution"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "resize_method"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "output_type",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "output"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"label": "normalization",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "output.normalization"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "output.apply_sky_clip"
|
||||
},
|
||||
"link": null
|
||||
},
|
||||
{
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"links": []
|
||||
},
|
||||
{
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"links": []
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"proxyWidgets": [
|
||||
[
|
||||
"96",
|
||||
"start_time"
|
||||
],
|
||||
[
|
||||
"96",
|
||||
"duration"
|
||||
],
|
||||
[
|
||||
"93",
|
||||
"resolution"
|
||||
],
|
||||
[
|
||||
"93",
|
||||
"resize_method"
|
||||
],
|
||||
[
|
||||
"92",
|
||||
"output"
|
||||
],
|
||||
[
|
||||
"92",
|
||||
"output.normalization"
|
||||
],
|
||||
[
|
||||
"92",
|
||||
"output.apply_sky_clip"
|
||||
],
|
||||
[
|
||||
"94",
|
||||
"model_name"
|
||||
]
|
||||
],
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0"
|
||||
},
|
||||
"widgets_values": [],
|
||||
"title": "Video Depth Estimation (Depth Anything 3)"
|
||||
}
|
||||
],
|
||||
"links": [],
|
||||
"version": 0.4,
|
||||
"definitions": {
|
||||
"subgraphs": [
|
||||
{
|
||||
"id": "253ec5ca-8333-4ddf-a036-9fc0923651b9",
|
||||
"version": 1,
|
||||
"state": {
|
||||
"lastGroupId": 4,
|
||||
"lastNodeId": 97,
|
||||
"lastLinkId": 129,
|
||||
"lastRerouteId": 0
|
||||
},
|
||||
"revision": 2,
|
||||
"config": {},
|
||||
"name": "Video Depth Estimation (Depth Anything 3)",
|
||||
"inputNode": {
|
||||
"id": -10,
|
||||
"bounding": [
|
||||
-230,
|
||||
130,
|
||||
167.912109375,
|
||||
228
|
||||
]
|
||||
},
|
||||
"outputNode": {
|
||||
"id": -20,
|
||||
"bounding": [
|
||||
1520,
|
||||
140,
|
||||
128,
|
||||
108
|
||||
]
|
||||
},
|
||||
"inputs": [
|
||||
{
|
||||
"id": "698c28c6-cf92-4039-8b39-f3062868ea7c",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"linkIds": [
|
||||
119
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
154
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "97a1f63e-1585-4a40-9dec-e2700120d84a",
|
||||
"name": "start_time",
|
||||
"type": "FLOAT",
|
||||
"linkIds": [
|
||||
121
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
174
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4dbbd3b3-c5ee-4a56-a0d3-3268d3b2fd64",
|
||||
"name": "duration",
|
||||
"type": "FLOAT",
|
||||
"linkIds": [
|
||||
122
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
194
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "16f55101-f99d-4c0c-bebf-c3b31c54f13e",
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"linkIds": [
|
||||
124
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
214
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "d9cd7693-4bb3-4ed7-9a75-276b997abcd9",
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
125
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
234
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "a6e90532-323b-462e-ba9c-1672384d5b31",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"linkIds": [
|
||||
126
|
||||
],
|
||||
"label": "output_type",
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
254
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "69e6aeef-437d-4fde-b2fc-d5ab9369238d",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
127
|
||||
],
|
||||
"label": "normalization",
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
274
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "73206f72-f89a-4698-885e-5d9277df2998",
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"linkIds": [
|
||||
128
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
294
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "dddbc7fc-9431-448a-9ed3-9aa62404288b",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"linkIds": [
|
||||
129
|
||||
],
|
||||
"pos": [
|
||||
-86.087890625,
|
||||
314
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"id": "478ab537-63bc-4d74-a9f0-c975f550880f",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"linkIds": [
|
||||
7
|
||||
],
|
||||
"localized_name": "IMAGE",
|
||||
"pos": [
|
||||
1544,
|
||||
164
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "cdaf037e-79bc-4a94-b06c-0fd32e76f615",
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"linkIds": [
|
||||
112
|
||||
],
|
||||
"pos": [
|
||||
1544,
|
||||
184
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": "4c0e5484-d193-49c7-b107-92619628880a",
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"linkIds": [
|
||||
113
|
||||
],
|
||||
"pos": [
|
||||
1544,
|
||||
204
|
||||
]
|
||||
}
|
||||
],
|
||||
"widgets": [],
|
||||
"nodes": [
|
||||
{
|
||||
"id": 92,
|
||||
"type": "DA3Render",
|
||||
"pos": [
|
||||
740,
|
||||
230
|
||||
],
|
||||
"size": [
|
||||
380,
|
||||
130
|
||||
],
|
||||
"flags": {},
|
||||
"order": 0,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "da3_geometry",
|
||||
"name": "da3_geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"link": 12
|
||||
},
|
||||
{
|
||||
"localized_name": "output",
|
||||
"name": "output",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "output"
|
||||
},
|
||||
"link": 126
|
||||
},
|
||||
{
|
||||
"localized_name": "output.normalization",
|
||||
"name": "output.normalization",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "output.normalization"
|
||||
},
|
||||
"link": 127
|
||||
},
|
||||
{
|
||||
"localized_name": "output.apply_sky_clip",
|
||||
"name": "output.apply_sky_clip",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "output.apply_sky_clip"
|
||||
},
|
||||
"link": 128
|
||||
},
|
||||
{
|
||||
"name": "geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "IMAGE",
|
||||
"name": "IMAGE",
|
||||
"type": "IMAGE",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
7
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DA3Render",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0"
|
||||
},
|
||||
"widgets_values": [
|
||||
"depth",
|
||||
"v2_style",
|
||||
false
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 93,
|
||||
"type": "DA3Inference",
|
||||
"pos": [
|
||||
740,
|
||||
-30
|
||||
],
|
||||
"size": [
|
||||
390,
|
||||
130
|
||||
],
|
||||
"flags": {},
|
||||
"order": 1,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "da3_model",
|
||||
"name": "da3_model",
|
||||
"type": "DA3_MODEL",
|
||||
"link": 107
|
||||
},
|
||||
{
|
||||
"localized_name": "image",
|
||||
"name": "image",
|
||||
"type": "IMAGE",
|
||||
"link": 111
|
||||
},
|
||||
{
|
||||
"localized_name": "resolution",
|
||||
"name": "resolution",
|
||||
"type": "INT",
|
||||
"widget": {
|
||||
"name": "resolution"
|
||||
},
|
||||
"link": 124
|
||||
},
|
||||
{
|
||||
"localized_name": "resize_method",
|
||||
"name": "resize_method",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "resize_method"
|
||||
},
|
||||
"link": 125
|
||||
},
|
||||
{
|
||||
"localized_name": "mode",
|
||||
"name": "mode",
|
||||
"type": "COMFY_DYNAMICCOMBO_V3",
|
||||
"widget": {
|
||||
"name": "mode"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "da3_geometry",
|
||||
"name": "da3_geometry",
|
||||
"type": "DA3_GEOMETRY",
|
||||
"slot_index": 0,
|
||||
"links": [
|
||||
12
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "DA3Inference",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.19.0"
|
||||
},
|
||||
"widgets_values": [
|
||||
504,
|
||||
"lower_bound_resize",
|
||||
"mono"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 94,
|
||||
"type": "LoadDA3Model",
|
||||
"pos": [
|
||||
50,
|
||||
410
|
||||
],
|
||||
"size": [
|
||||
400,
|
||||
140
|
||||
],
|
||||
"flags": {},
|
||||
"order": 2,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "model_name",
|
||||
"name": "model_name",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "model_name"
|
||||
},
|
||||
"link": 129
|
||||
},
|
||||
{
|
||||
"localized_name": "weight_dtype",
|
||||
"name": "weight_dtype",
|
||||
"type": "COMBO",
|
||||
"widget": {
|
||||
"name": "weight_dtype"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "DA3_MODEL",
|
||||
"name": "DA3_MODEL",
|
||||
"type": "DA3_MODEL",
|
||||
"links": [
|
||||
107
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "LoadDA3Model",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0",
|
||||
"models": [
|
||||
{
|
||||
"name": "depth_anything_3_mono_large.safetensors",
|
||||
"url": "https://huggingface.co/Comfy-Org/Depth-Anything-3/resolve/main/geometry_estimation/depth_anything_3_mono_large.safetensors",
|
||||
"directory": "geometry_estimation"
|
||||
}
|
||||
]
|
||||
},
|
||||
"widgets_values": [
|
||||
"depth_anything_3_mono_large.safetensors",
|
||||
"default"
|
||||
]
|
||||
},
|
||||
{
|
||||
"id": 95,
|
||||
"type": "GetVideoComponents",
|
||||
"pos": [
|
||||
70,
|
||||
-140
|
||||
],
|
||||
"size": [
|
||||
260,
|
||||
120
|
||||
],
|
||||
"flags": {},
|
||||
"order": 3,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "video",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"link": 120
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "images",
|
||||
"name": "images",
|
||||
"type": "IMAGE",
|
||||
"links": [
|
||||
111
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "audio",
|
||||
"name": "audio",
|
||||
"type": "AUDIO",
|
||||
"links": [
|
||||
112
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "fps",
|
||||
"name": "fps",
|
||||
"type": "FLOAT",
|
||||
"links": [
|
||||
113
|
||||
]
|
||||
},
|
||||
{
|
||||
"localized_name": "bit_depth",
|
||||
"name": "bit_depth",
|
||||
"type": "INT",
|
||||
"links": null
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "GetVideoComponents",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 96,
|
||||
"type": "Video Slice",
|
||||
"pos": [
|
||||
70,
|
||||
-360
|
||||
],
|
||||
"size": [
|
||||
270,
|
||||
170
|
||||
],
|
||||
"flags": {},
|
||||
"order": 4,
|
||||
"mode": 0,
|
||||
"inputs": [
|
||||
{
|
||||
"localized_name": "video",
|
||||
"name": "video",
|
||||
"type": "VIDEO",
|
||||
"link": 119
|
||||
},
|
||||
{
|
||||
"localized_name": "start_time",
|
||||
"name": "start_time",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "start_time"
|
||||
},
|
||||
"link": 121
|
||||
},
|
||||
{
|
||||
"localized_name": "duration",
|
||||
"name": "duration",
|
||||
"type": "FLOAT",
|
||||
"widget": {
|
||||
"name": "duration"
|
||||
},
|
||||
"link": 122
|
||||
},
|
||||
{
|
||||
"localized_name": "strict_duration",
|
||||
"name": "strict_duration",
|
||||
"type": "BOOLEAN",
|
||||
"widget": {
|
||||
"name": "strict_duration"
|
||||
},
|
||||
"link": null
|
||||
}
|
||||
],
|
||||
"outputs": [
|
||||
{
|
||||
"localized_name": "VIDEO",
|
||||
"name": "VIDEO",
|
||||
"type": "VIDEO",
|
||||
"links": [
|
||||
120
|
||||
]
|
||||
}
|
||||
],
|
||||
"properties": {
|
||||
"Node name for S&R": "Video Slice",
|
||||
"cnr_id": "comfy-core",
|
||||
"ver": "0.24.0"
|
||||
},
|
||||
"widgets_values": [
|
||||
0,
|
||||
5,
|
||||
false
|
||||
]
|
||||
}
|
||||
],
|
||||
"groups": [],
|
||||
"links": [
|
||||
{
|
||||
"id": 12,
|
||||
"origin_id": 93,
|
||||
"origin_slot": 0,
|
||||
"target_id": 92,
|
||||
"target_slot": 0,
|
||||
"type": "DA3_GEOMETRY"
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"origin_id": 92,
|
||||
"origin_slot": 0,
|
||||
"target_id": -20,
|
||||
"target_slot": 0,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 107,
|
||||
"origin_id": 94,
|
||||
"origin_slot": 0,
|
||||
"target_id": 93,
|
||||
"target_slot": 0,
|
||||
"type": "DA3_MODEL"
|
||||
},
|
||||
{
|
||||
"id": 111,
|
||||
"origin_id": 95,
|
||||
"origin_slot": 0,
|
||||
"target_id": 93,
|
||||
"target_slot": 1,
|
||||
"type": "IMAGE"
|
||||
},
|
||||
{
|
||||
"id": 112,
|
||||
"origin_id": 95,
|
||||
"origin_slot": 1,
|
||||
"target_id": -20,
|
||||
"target_slot": 1,
|
||||
"type": "AUDIO"
|
||||
},
|
||||
{
|
||||
"id": 113,
|
||||
"origin_id": 95,
|
||||
"origin_slot": 2,
|
||||
"target_id": -20,
|
||||
"target_slot": 2,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 119,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 0,
|
||||
"target_id": 96,
|
||||
"target_slot": 0,
|
||||
"type": "VIDEO"
|
||||
},
|
||||
{
|
||||
"id": 120,
|
||||
"origin_id": 96,
|
||||
"origin_slot": 0,
|
||||
"target_id": 95,
|
||||
"target_slot": 0,
|
||||
"type": "VIDEO"
|
||||
},
|
||||
{
|
||||
"id": 121,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 1,
|
||||
"target_id": 96,
|
||||
"target_slot": 1,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 122,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 2,
|
||||
"target_id": 96,
|
||||
"target_slot": 2,
|
||||
"type": "FLOAT"
|
||||
},
|
||||
{
|
||||
"id": 124,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 3,
|
||||
"target_id": 93,
|
||||
"target_slot": 2,
|
||||
"type": "INT"
|
||||
},
|
||||
{
|
||||
"id": 125,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 4,
|
||||
"target_id": 93,
|
||||
"target_slot": 3,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 126,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 5,
|
||||
"target_id": 92,
|
||||
"target_slot": 1,
|
||||
"type": "COMFY_DYNAMICCOMBO_V3"
|
||||
},
|
||||
{
|
||||
"id": 127,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 6,
|
||||
"target_id": 92,
|
||||
"target_slot": 2,
|
||||
"type": "COMBO"
|
||||
},
|
||||
{
|
||||
"id": 128,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 7,
|
||||
"target_id": 92,
|
||||
"target_slot": 3,
|
||||
"type": "BOOLEAN"
|
||||
},
|
||||
{
|
||||
"id": 129,
|
||||
"origin_id": -10,
|
||||
"origin_slot": 8,
|
||||
"target_id": 94,
|
||||
"target_slot": 0,
|
||||
"type": "COMBO"
|
||||
}
|
||||
],
|
||||
"extra": {},
|
||||
"category": "Conditioning & Preprocessors/Depth",
|
||||
"description": "This subgraph processes a video input through Depth Anything 3 to produce temporally consistent depth maps for each frame, outputting a depth video. It is ideal for video content requiring spatial geometry estimation, such as 3D reconstruction, SLAM, or novel view synthesis from moving cameras. The model uses a plain transformer backbone trained with a depth-ray representation, supporting any number of views without requiring known camera poses."
|
||||
}
|
||||
]
|
||||
},
|
||||
"extra": {
|
||||
"BlueprintDescription": "This subgraph processes a video input through Depth Anything 3 to produce temporally consistent depth maps for each frame, outputting a depth video. It is ideal for video content requiring spatial geometry estimation, such as 3D reconstruction, SLAM, or novel view synthesis from moving cameras. The model uses a plain transformer backbone trained with a depth-ray representation, supporting any number of views without requiring known camera poses."
|
||||
}
|
||||
}
|
||||
3732
blueprints/Video Edit (Bernini-R).json
Normal file
3732
blueprints/Video Edit (Bernini-R).json
Normal file
File diff suppressed because it is too large
Load Diff
@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
@ -144,6 +145,7 @@ vram_group.add_argument("--novram", action="store_true", help="When lowvram isn'
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
parser.add_argument("--vram-headroom", type=float, default=0, help="Set the amount of vram in GB for DynamicVRAM to maintain as extra headroom above default. ComfyUI will try and keep this much VRAM completely free and unused, even counting VRAM from other apps.")
|
||||
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
@ -166,6 +168,8 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
@ -247,6 +251,9 @@ else:
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.high_ram:
|
||||
args.cache_classic = True
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -256,6 +263,10 @@ if args.disable_auto_launch:
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
|
||||
if args.enable_manager_legacy_ui:
|
||||
args.enable_manager = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
|
||||
@ -8,6 +8,8 @@ from abc import ABC, abstractmethod
|
||||
import logging
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
import comfy.conds
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
@ -51,12 +53,18 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0, modality_windows: dict=None, context_overlap: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.context_overlap = context_overlap
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
self.modality_windows = modality_windows # dict of {mod_idx: IndexListContextWindow}
|
||||
self.guide_frames_indices: list[int] = []
|
||||
self.guide_overlap_info: list[tuple[int, int]] = []
|
||||
self.guide_kf_local_positions: list[int] = []
|
||||
self.guide_downscale_factors: list[int] = []
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
@ -85,6 +93,11 @@ class IndexListContextWindow(ContextWindowABC):
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
def get_window_for_modality(self, modality_idx: int) -> 'IndexListContextWindow':
|
||||
if modality_idx == 0:
|
||||
return self
|
||||
return self.modality_windows[modality_idx]
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -148,6 +161,172 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
|
||||
def compute_guide_overlap(guide_entries: list[dict], keyframe_idxs: torch.Tensor, temporal_downscale_ratio: int, window_index_list: list[int]):
|
||||
"""Compute which concatenated guide frames overlap with a context window.
|
||||
|
||||
Each guide's latent-space start is derived from its first token's pixel-t-start
|
||||
in keyframe_idxs (shape (B, [t,h,w], num_tokens, [start, end])), divided by the
|
||||
model's temporal_downscale_ratio.
|
||||
|
||||
Args:
|
||||
guide_entries: list of guide_attention_entry dicts
|
||||
keyframe_idxs: per-token pixel coords cond tensor for the modality
|
||||
temporal_downscale_ratio: model's pixel-to-latent temporal compression ratio
|
||||
window_index_list: the window's frame indices into the video portion
|
||||
|
||||
Returns:
|
||||
suffix_indices: indices into the guide_frames tensor for frame selection
|
||||
overlap_info: list of (entry_idx, overlap_count) for guide_attention_entries adjustment
|
||||
kf_local_positions: window-local frame positions for keyframe_idxs regeneration
|
||||
total_overlap: total number of overlapping guide frames
|
||||
"""
|
||||
window_set = set(window_index_list)
|
||||
window_list = list(window_index_list)
|
||||
suffix_indices = []
|
||||
overlap_info = []
|
||||
kf_local_positions = []
|
||||
suffix_base = 0
|
||||
token_offset = 0
|
||||
|
||||
for entry_idx, entry in enumerate(guide_entries):
|
||||
first_t_pixel = int(keyframe_idxs[0, 0, token_offset, 0].item())
|
||||
latent_start = (first_t_pixel + temporal_downscale_ratio - 1) // temporal_downscale_ratio
|
||||
guide_len = entry["latent_shape"][0]
|
||||
entry_overlap = 0
|
||||
|
||||
for local_offset in range(guide_len):
|
||||
video_pos = latent_start + local_offset
|
||||
if video_pos in window_set:
|
||||
suffix_indices.append(suffix_base + local_offset)
|
||||
kf_local_positions.append(window_list.index(video_pos))
|
||||
entry_overlap += 1
|
||||
|
||||
if entry_overlap > 0:
|
||||
overlap_info.append((entry_idx, entry_overlap))
|
||||
suffix_base += guide_len
|
||||
token_offset += entry["pre_filter_count"]
|
||||
|
||||
return suffix_indices, overlap_info, kf_local_positions, len(suffix_indices)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WindowingState:
|
||||
"""Per-modality context windowing state for each step,
|
||||
built using IndexListContextHandler._build_window_state().
|
||||
For non-multimodal models the lists are length 1
|
||||
"""
|
||||
latents: list[torch.Tensor] # per-modality working latents (guide frames stripped)
|
||||
guide_latents: list[torch.Tensor | None] # per-modality guide frames stripped from latents
|
||||
guide_entries: list[list[dict] | None] # per-modality guide_attention_entry metadata
|
||||
keyframe_idxs: list[torch.Tensor | None] # per-modality keyframe_idxs tensor for guide latent_start derivation
|
||||
latent_shapes: list | None # original packed shapes for unpack/pack (None if not multimodal)
|
||||
dim: int = 0 # primary modality temporal dim for context windowing
|
||||
is_multimodal: bool = False
|
||||
temporal_downscale_ratio: int = 1 # model's pixel-to-latent temporal compression ratio
|
||||
|
||||
def prepare_window(self, window: IndexListContextWindow, model) -> IndexListContextWindow:
|
||||
"""Reformat window for multimodal contexts by deriving per-modality index lists.
|
||||
Non-multimodal contexts return the input window unchanged.
|
||||
"""
|
||||
if not self.is_multimodal:
|
||||
return window
|
||||
|
||||
x = self.latents[0]
|
||||
primary_total = self.latent_shapes[0][self.dim]
|
||||
primary_overlap = window.context_overlap
|
||||
map_shapes = self.latent_shapes
|
||||
if x.size(self.dim) != primary_total:
|
||||
map_shapes = list(self.latent_shapes)
|
||||
video_shape = list(self.latent_shapes[0])
|
||||
video_shape[self.dim] = x.size(self.dim)
|
||||
map_shapes[0] = torch.Size(video_shape)
|
||||
try:
|
||||
per_modality_indices = model.map_context_window_to_modalities(
|
||||
window.index_list, map_shapes, self.dim)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
f"{type(model).__name__} must implement map_context_window_to_modalities for multimodal context windows.")
|
||||
modality_windows = {}
|
||||
for mod_idx in range(1, len(self.latents)):
|
||||
modality_total_frames = self.latents[mod_idx].shape[self.dim]
|
||||
ratio = modality_total_frames / primary_total if primary_total > 0 else 1
|
||||
modality_overlap = max(round(primary_overlap * ratio), 0)
|
||||
modality_windows[mod_idx] = IndexListContextWindow(
|
||||
per_modality_indices[mod_idx], dim=self.dim,
|
||||
total_frames=modality_total_frames,
|
||||
context_overlap=modality_overlap)
|
||||
return IndexListContextWindow(
|
||||
window.index_list, dim=self.dim, total_frames=x.shape[self.dim],
|
||||
modality_windows=modality_windows, context_overlap=primary_overlap)
|
||||
|
||||
def slice_for_window(self, window: IndexListContextWindow, retain_index_list: list[int], device=None) -> tuple[list[torch.Tensor], list[int]]:
|
||||
"""Slice latents for a context window, injecting guide frames where applicable.
|
||||
For multimodal contexts, uses the modality-specific windows derived in prepare_window().
|
||||
"""
|
||||
sliced = []
|
||||
guide_frame_counts = []
|
||||
for idx in range(len(self.latents)):
|
||||
modality_window = window.get_window_for_modality(idx)
|
||||
retain = retain_index_list if idx == 0 else []
|
||||
s = modality_window.get_tensor(self.latents[idx], device, retain_index_list=retain)
|
||||
if self.guide_entries[idx] is not None:
|
||||
s, ng = self._inject_guide_frames(s, modality_window, modality_idx=idx)
|
||||
else:
|
||||
ng = 0
|
||||
sliced.append(s)
|
||||
guide_frame_counts.append(ng)
|
||||
return sliced, guide_frame_counts
|
||||
|
||||
def strip_guide_frames(self, out_per_modality: list[list[torch.Tensor]], guide_frame_counts: list[int], window: IndexListContextWindow):
|
||||
"""Strip injected guide frames from per-cond, per-modality outputs in place."""
|
||||
for idx in range(len(self.latents)):
|
||||
if guide_frame_counts[idx] > 0:
|
||||
window_len = len(window.get_window_for_modality(idx).index_list)
|
||||
for ci in range(len(out_per_modality)):
|
||||
out_per_modality[ci][idx] = out_per_modality[ci][idx].narrow(self.dim, 0, window_len)
|
||||
|
||||
def _inject_guide_frames(self, latent_slice: torch.Tensor, window: IndexListContextWindow, modality_idx: int = 0) -> tuple[torch.Tensor, int]:
|
||||
guide_entries = self.guide_entries[modality_idx]
|
||||
guide_frames = self.guide_latents[modality_idx]
|
||||
keyframe_idxs = self.keyframe_idxs[modality_idx]
|
||||
suffix_idx, overlap_info, kf_local_pos, guide_frame_count = compute_guide_overlap(
|
||||
guide_entries, keyframe_idxs, self.temporal_downscale_ratio, window.index_list)
|
||||
# Shift keyframe positions to account for causal_window_fix anchor occupying sub-pos 0.
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
kf_local_pos = [p + 1 for p in kf_local_pos]
|
||||
window.guide_frames_indices = suffix_idx
|
||||
window.guide_overlap_info = overlap_info
|
||||
window.guide_kf_local_positions = kf_local_pos
|
||||
|
||||
# Derive per-overlap-entry latent_downscale_factor from guide entry latent_shape vs guide frame spatial dims.
|
||||
# guide_frames has full (post-dilation) spatial dims; entry["latent_shape"] has pre-dilation dims.
|
||||
guide_downscale_factors = []
|
||||
if guide_frame_count > 0:
|
||||
full_H = guide_frames.shape[3]
|
||||
for entry_idx, _ in overlap_info:
|
||||
entry_H = guide_entries[entry_idx]["latent_shape"][1]
|
||||
guide_downscale_factors.append(full_H // entry_H)
|
||||
window.guide_downscale_factors = guide_downscale_factors
|
||||
|
||||
if guide_frame_count > 0:
|
||||
idx = tuple([slice(None)] * self.dim + [suffix_idx])
|
||||
return torch.cat([latent_slice, guide_frames[idx]], dim=self.dim), guide_frame_count
|
||||
return latent_slice, 0
|
||||
|
||||
def patch_latent_shapes(self, sub_conds, new_shapes):
|
||||
if not self.is_multimodal:
|
||||
return
|
||||
|
||||
for cond_list in sub_conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
model_conds['latent_shapes'] = comfy.conds.CONDConstant(new_shapes)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextSchedule:
|
||||
name: str
|
||||
@ -162,7 +341,7 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
|
||||
causal_window_fix: bool=True):
|
||||
latent_retain_index_list: list[int]=[], causal_window_fix: bool=True):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -174,17 +353,118 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
self.latent_retain_index_list = [int(x.strip()) for x in latent_retain_index_list.split(",")] if latent_retain_index_list else []
|
||||
self.causal_window_fix = causal_window_fix
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_latent_shapes(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
if 'latent_shapes' in model_conds:
|
||||
return model_conds['latent_shapes'].cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_keyframe_idxs(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
kf = model_conds.get('keyframe_idxs')
|
||||
if kf is not None and hasattr(kf, 'cond') and kf.cond is not None:
|
||||
return kf.cond
|
||||
return None
|
||||
|
||||
def _apply_freenoise(self, noise: torch.Tensor, conds: list[list[dict]], seed: int) -> torch.Tensor:
|
||||
"""Apply FreeNoise shuffling, scaling context length/overlap per-modality by frame ratio.
|
||||
If guide frames are present on the primary modality, only the video portion is shuffled.
|
||||
"""
|
||||
guide_entries = self._get_guide_entries(conds)
|
||||
guide_count = sum(e["latent_shape"][0] for e in guide_entries) if guide_entries else 0
|
||||
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
if latent_shapes is not None and len(latent_shapes) > 1:
|
||||
modalities = comfy.utils.unpack_latents(noise, latent_shapes)
|
||||
primary_total = latent_shapes[0][self.dim]
|
||||
primary_video_count = modalities[0].size(self.dim) - guide_count
|
||||
apply_freenoise(modalities[0].narrow(self.dim, 0, primary_video_count), self.dim, self.context_length, self.context_overlap, seed)
|
||||
for i in range(1, len(modalities)):
|
||||
mod_total = latent_shapes[i][self.dim]
|
||||
ratio = mod_total / primary_total if primary_total > 0 else 1
|
||||
mod_ctx_len = max(round(self.context_length * ratio), 1)
|
||||
mod_ctx_overlap = max(round(self.context_overlap * ratio), 0)
|
||||
modalities[i] = apply_freenoise(modalities[i], self.dim, mod_ctx_len, mod_ctx_overlap, seed)
|
||||
noise, _ = comfy.utils.pack_latents(modalities)
|
||||
return noise
|
||||
video_count = noise.size(self.dim) - guide_count
|
||||
apply_freenoise(noise.narrow(self.dim, 0, video_count), self.dim, self.context_length, self.context_overlap, seed)
|
||||
return noise
|
||||
|
||||
def _build_window_state(self, x_in: torch.Tensor, conds: list[list[dict]], model: BaseModel) -> WindowingState:
|
||||
"""Build windowing state for the current step, including unpacking latents and extracting guide frame info from conds."""
|
||||
latent_shapes = self._get_latent_shapes(conds)
|
||||
is_multimodal = latent_shapes is not None and len(latent_shapes) > 1
|
||||
unpacked_latents = comfy.utils.unpack_latents(x_in, latent_shapes) if is_multimodal else [x_in]
|
||||
|
||||
unpacked_latents_list = list(unpacked_latents)
|
||||
guide_latents_list = [None] * len(unpacked_latents)
|
||||
guide_entries_list = [None] * len(unpacked_latents)
|
||||
keyframe_idxs_list = [None] * len(unpacked_latents)
|
||||
|
||||
extracted_guide_entries = self._get_guide_entries(conds)
|
||||
extracted_keyframe_idxs = self._get_keyframe_idxs(conds)
|
||||
|
||||
# Strip guide frames (only from first modality for now)
|
||||
if extracted_guide_entries is not None:
|
||||
guide_count = sum(e["latent_shape"][0] for e in extracted_guide_entries)
|
||||
if guide_count > 0:
|
||||
x = unpacked_latents[0]
|
||||
latent_count = x.size(self.dim) - guide_count
|
||||
unpacked_latents_list[0] = x.narrow(self.dim, 0, latent_count)
|
||||
guide_latents_list[0] = x.narrow(self.dim, latent_count, guide_count)
|
||||
guide_entries_list[0] = extracted_guide_entries
|
||||
keyframe_idxs_list[0] = extracted_keyframe_idxs
|
||||
|
||||
|
||||
return WindowingState(
|
||||
latents=unpacked_latents_list,
|
||||
guide_latents=guide_latents_list,
|
||||
guide_entries=guide_entries_list,
|
||||
keyframe_idxs=keyframe_idxs_list,
|
||||
latent_shapes=latent_shapes,
|
||||
dim=self.dim,
|
||||
is_multimodal=is_multimodal,
|
||||
temporal_downscale_ratio=model.latent_format.temporal_downscale_ratio)
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
window_state = self._build_window_state(x_in, conds, model) # build window_state to check frame counts, will be built again in execute
|
||||
total_frame_count = window_state.latents[0].size(self.dim)
|
||||
if total_frame_count > self.context_length:
|
||||
logging.info(f"\nUsing context windows: Context length {self.context_length} with overlap {self.context_overlap} for {total_frame_count} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
if self.latent_retain_index_list:
|
||||
logging.info(f"Retaining original latent for indexes: {self.latent_retain_index_list}")
|
||||
return True
|
||||
logging.info(f"\nNot using context windows since context length ({self.context_length}) exceeds input frames ({total_frame_count}).")
|
||||
return False
|
||||
|
||||
def prepare_control_objects(self, control: ControlBase, device=None) -> ControlBase:
|
||||
@ -275,7 +555,9 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
sample_sigmas = model_options["transformer_options"]["sample_sigmas"]
|
||||
current_timestep = timestep[0].to(sample_sigmas.dtype)
|
||||
mask = torch.isclose(sample_sigmas, current_timestep, rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
return # substep from multi-step sampler: keep self._step from the last full step
|
||||
@ -284,54 +566,98 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length, context_overlap=self.context_overlap) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
self._model = model
|
||||
self.set_step(timestep, model_options)
|
||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
|
||||
conds_final = [torch.zeros_like(x_in) for _ in conds]
|
||||
window_state = self._build_window_state(x_in, conds, model)
|
||||
num_modalities = len(window_state.latents)
|
||||
|
||||
context_windows = self.get_context_windows(model, window_state.latents[0], model_options)
|
||||
enumerated_context_windows = list(enumerate(context_windows))
|
||||
total_windows = len(enumerated_context_windows)
|
||||
|
||||
# Initialize per-modality accumulators (length 1 for single-modality)
|
||||
accum = [[torch.zeros_like(m) for _ in conds] for m in window_state.latents]
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
counts_final = [torch.ones(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
counts = [[torch.ones(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
else:
|
||||
counts_final = [torch.zeros(get_shape_for_dim(x_in, self.dim), device=x_in.device) for _ in conds]
|
||||
biases_final = [([0.0] * x_in.shape[self.dim]) for _ in conds]
|
||||
counts = [[torch.zeros(get_shape_for_dim(m, self.dim), device=m.device) for _ in conds] for m in window_state.latents]
|
||||
biases = [[([0.0] * m.shape[self.dim]) for _ in conds] for m in window_state.latents]
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_START, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
# accumulate results from each context window
|
||||
for enum_window in enumerated_context_windows:
|
||||
results = self.evaluate_context_windows(calc_cond_batch, model, x_in, conds, timestep, [enum_window], model_options)
|
||||
results = self.evaluate_context_windows(
|
||||
calc_cond_batch, model, x_in, conds, timestep, [enum_window],
|
||||
model_options, window_state=window_state, total_windows=total_windows)
|
||||
for result in results:
|
||||
self.combine_context_window_results(x_in, result.sub_conds_out, result.sub_conds, result.window, result.window_idx, len(enumerated_context_windows), timestep,
|
||||
conds_final, counts_final, biases_final)
|
||||
# result.sub_conds_out is per-cond, per-modality: list[list[Tensor]]
|
||||
for mod_idx in range(num_modalities):
|
||||
mod_out = [result.sub_conds_out[ci][mod_idx] for ci in range(len(conds))]
|
||||
modality_window = result.window.get_window_for_modality(mod_idx)
|
||||
self.combine_context_window_results(
|
||||
window_state.latents[mod_idx], mod_out, result.sub_conds, modality_window,
|
||||
result.window_idx, total_windows, timestep,
|
||||
accum[mod_idx], counts[mod_idx], biases[mod_idx])
|
||||
|
||||
# fuse accumulated results into final conds
|
||||
try:
|
||||
# finalize conds
|
||||
if self.fuse_method.name == ContextFuseMethods.RELATIVE:
|
||||
# relative is already normalized, so return as is
|
||||
del counts_final
|
||||
return conds_final
|
||||
else:
|
||||
# normalize conds via division by context usage counts
|
||||
for i in range(len(conds_final)):
|
||||
conds_final[i] /= counts_final[i]
|
||||
del counts_final
|
||||
return conds_final
|
||||
result_out = []
|
||||
for ci in range(len(conds)):
|
||||
finalized = []
|
||||
for mod_idx in range(num_modalities):
|
||||
if self.fuse_method.name != ContextFuseMethods.RELATIVE:
|
||||
accum[mod_idx][ci] /= counts[mod_idx][ci]
|
||||
f = accum[mod_idx][ci]
|
||||
|
||||
# if guide frames were injected, append them to the end of the fused latents for the next step
|
||||
if window_state.guide_latents[mod_idx] is not None:
|
||||
f = torch.cat([f, window_state.guide_latents[mod_idx]], dim=self.dim)
|
||||
finalized.append(f)
|
||||
|
||||
# pack modalities together if needed
|
||||
if window_state.is_multimodal and len(finalized) > 1:
|
||||
packed, _ = comfy.utils.pack_latents(finalized)
|
||||
else:
|
||||
packed = finalized[0]
|
||||
|
||||
result_out.append(packed)
|
||||
return result_out
|
||||
finally:
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EXECUTE_CLEANUP, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options)
|
||||
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds, timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, device=None, first_device=None):
|
||||
def evaluate_context_windows(self, calc_cond_batch: Callable, model: BaseModel, x_in: torch.Tensor, conds,
|
||||
timestep: torch.Tensor, enumerated_context_windows: list[tuple[int, IndexListContextWindow]],
|
||||
model_options, window_state: WindowingState, total_windows: int = None,
|
||||
device=None, first_device=None):
|
||||
"""Evaluate context windows and return per-cond, per-modality outputs in ContextResults.sub_conds_out
|
||||
|
||||
For each window:
|
||||
1. Builds windows (for each modality if multimodal)
|
||||
2. Slices window for each modality
|
||||
3. Injects concatenated latent guide frames where present
|
||||
4. Packs together if needed and calls model
|
||||
5. Unpacks and strips any guides from outputs
|
||||
"""
|
||||
x = window_state.latents[0]
|
||||
|
||||
results: list[ContextResults] = []
|
||||
for window_idx, window in enumerated_context_windows:
|
||||
# allow processing to end between context window executions for faster Cancel
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward
|
||||
# prepare the window accounting for multimodal windows
|
||||
window = window_state.prepare_window(window, model)
|
||||
|
||||
# causal_window_fix: prepend a pre-window frame that will be stripped post-forward.
|
||||
# Set anchor before slice_for_window so the latent slice and downstream cond slices both pick it up.
|
||||
anchor_applied = False
|
||||
if self.causal_window_fix:
|
||||
anchor_idx = window.index_list[0] - 1
|
||||
@ -339,27 +665,46 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
window.causal_anchor_index = anchor_idx
|
||||
anchor_applied = True
|
||||
|
||||
# slice the window for each modality, injecting guide frames where applicable
|
||||
sliced, guide_frame_counts_per_modality = window_state.slice_for_window(window, self.latent_retain_index_list, device)
|
||||
|
||||
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
|
||||
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
|
||||
|
||||
# update exposed params
|
||||
logging.info(f"Context window {window_idx + 1}/{total_windows or len(enumerated_context_windows)}: frames {window.index_list[0]}-{window.index_list[-1]} of {x.shape[self.dim]}"
|
||||
+ (f" (+{guide_frame_counts_per_modality[0]} guide frames)" if guide_frame_counts_per_modality[0] > 0 else "")
|
||||
)
|
||||
|
||||
# if multimodal, pack modalities together
|
||||
if window_state.is_multimodal and len(sliced) > 1:
|
||||
sub_x, sub_shapes = comfy.utils.pack_latents(sliced)
|
||||
else:
|
||||
sub_x, sub_shapes = sliced[0], [sliced[0].shape]
|
||||
|
||||
# get resized conds for window
|
||||
model_options["transformer_options"]["context_window"] = window
|
||||
# get subsections of x, timestep, conds
|
||||
sub_x = window.get_tensor(x_in, device)
|
||||
sub_timestep = window.get_tensor(timestep, device, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x_in, window, device) for cond in conds]
|
||||
sub_timestep = window.get_tensor(timestep, dim=0)
|
||||
sub_conds = [self.get_resized_cond(cond, x, window) for cond in conds]
|
||||
|
||||
# if multimodal, patch latent_shapes in conds for correct unpacking in model
|
||||
window_state.patch_latent_shapes(sub_conds, sub_shapes)
|
||||
|
||||
# call model on window
|
||||
sub_conds_out = calc_cond_batch(model, sub_conds, sub_x, sub_timestep, model_options)
|
||||
if device is not None:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
|
||||
|
||||
# strip causal_window_fix anchor if applied
|
||||
# unpack outputs
|
||||
out_per_modality = [comfy.utils.unpack_latents(sub_conds_out[i], sub_shapes) for i in range(len(sub_conds_out))]
|
||||
|
||||
# strip causal_window_fix anchor from primary modality before guide strip so window_len math stays correct
|
||||
if anchor_applied:
|
||||
for i in range(len(sub_conds_out)):
|
||||
sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
|
||||
for ci in range(len(out_per_modality)):
|
||||
t = out_per_modality[ci][0]
|
||||
out_per_modality[ci][0] = t.narrow(self.dim, 1, t.shape[self.dim] - 1)
|
||||
|
||||
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
|
||||
# strip injected guide frames
|
||||
window_state.strip_guide_frames(out_per_modality, guide_frame_counts_per_modality, window)
|
||||
|
||||
results.append(ContextResults(window_idx, out_per_modality, sub_conds, window))
|
||||
return results
|
||||
|
||||
|
||||
@ -383,7 +728,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
else:
|
||||
# add conds and counts based on weights of fuse method
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep)
|
||||
weights = get_context_weights(window.context_length, x_in.shape[self.dim], window.index_list, self, sigma=timestep, context_overlap=window.context_overlap)
|
||||
weights_tensor = match_weights_to_dim(weights, x_in, self.dim, device=x_in.device)
|
||||
for i in range(len(sub_conds_out)):
|
||||
window.add_window(conds_final[i], sub_conds_out[i] * weights_tensor)
|
||||
@ -393,16 +738,22 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
callback(self, x_in, sub_conds_out, sub_conds, window, window_idx, total_windows, timestep, conds_final, counts_final, biases_final)
|
||||
|
||||
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, *args, **kwargs):
|
||||
# limit noise_shape length to context_length for more accurate vram use estimation
|
||||
def _prepare_sampling_wrapper(executor, model, noise_shape: torch.Tensor, conds, *args, **kwargs):
|
||||
# Scale noise_shape to a single context window so VRAM estimation budgets per-window.
|
||||
model_options = kwargs.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in prepare_sampling_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is not None:
|
||||
noise_shape = list(noise_shape)
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, *args, **kwargs)
|
||||
is_packed = len(noise_shape) == 3 and noise_shape[1] == 1
|
||||
if is_packed:
|
||||
# TODO: latent_shapes cond isn't attached yet at this point, so we can't compute a
|
||||
# per-window flat latent here. Skipping the clamp over-estimates but prevents immediate OOM.
|
||||
pass
|
||||
elif handler.dim < len(noise_shape) and noise_shape[handler.dim] > handler.context_length:
|
||||
noise_shape[handler.dim] = min(noise_shape[handler.dim], handler.context_length)
|
||||
return executor(model, noise_shape, conds, *args, **kwargs)
|
||||
|
||||
|
||||
def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
@ -422,11 +773,12 @@ def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, nois
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
conds = [guider.conds.get('positive', guider.conds.get('negative', []))]
|
||||
noise = handler._apply_freenoise(noise, conds, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
@ -434,7 +786,6 @@ def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -580,8 +931,9 @@ def get_matching_context_schedule(context_schedule: str) -> ContextSchedule:
|
||||
return ContextSchedule(context_schedule, func)
|
||||
|
||||
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None):
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs)
|
||||
def get_context_weights(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, sigma: torch.Tensor=None, context_overlap: int=None):
|
||||
context_overlap = handler.context_overlap if context_overlap is None else context_overlap
|
||||
return handler.fuse_method.func(length, sigma=sigma, handler=handler, full_length=full_length, idxs=idxs, context_overlap=context_overlap)
|
||||
|
||||
|
||||
def create_weights_flat(length: int, **kwargs) -> list[float]:
|
||||
@ -599,18 +951,18 @@ def create_weights_pyramid(length: int, **kwargs) -> list[float]:
|
||||
weight_sequence = list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1))
|
||||
return weight_sequence
|
||||
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], handler: IndexListContextHandler, **kwargs):
|
||||
def create_weights_overlap_linear(length: int, full_length: int, idxs: list[int], context_overlap: int, **kwargs):
|
||||
# based on code in Kijai's WanVideoWrapper: https://github.com/kijai/ComfyUI-WanVideoWrapper/blob/dbb2523b37e4ccdf45127e5ae33e31362f755c8e/nodes.py#L1302
|
||||
# only expected overlap is given different weights
|
||||
weights_torch = torch.ones((length))
|
||||
# blend left-side on all except first window
|
||||
if min(idxs) > 0:
|
||||
ramp_up = torch.linspace(1e-37, 1, handler.context_overlap)
|
||||
weights_torch[:handler.context_overlap] = ramp_up
|
||||
ramp_up = torch.linspace(1e-37, 1, context_overlap)
|
||||
weights_torch[:context_overlap] = ramp_up
|
||||
# blend right-side on all except last window
|
||||
if max(idxs) < full_length-1:
|
||||
ramp_down = torch.linspace(1, 1e-37, handler.context_overlap)
|
||||
weights_torch[-handler.context_overlap:] = ramp_down
|
||||
ramp_down = torch.linspace(1, 1e-37, context_overlap)
|
||||
weights_torch[-context_overlap:] = ramp_down
|
||||
return weights_torch
|
||||
|
||||
class ContextFuseMethods:
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.depth_anything_3.reference_view_selector import (
|
||||
select_reference_view, reorder_by_reference, restore_original_order,
|
||||
THRESH_FOR_REF_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = embed_dim // heads
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
|
||||
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
|
||||
if self.q_norm is None and rope is None:
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
|
||||
attn = self.attention
|
||||
B, N, C = x.shape
|
||||
h = self.heads
|
||||
d = self.head_dim
|
||||
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
|
||||
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
|
||||
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if rope is not None and pos is not None:
|
||||
q = rope(q, pos)
|
||||
k = rope(k, pos)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
return self.output(out)
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=qk_norm)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module):
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
|
||||
pos=pos, rope=rope))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D Rotary position embedding (DA3 extension)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PositionGetter:
|
||||
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width, device)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(torch.nn.Module):
|
||||
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
|
||||
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
|
||||
qknorm_start: int = -1):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([
|
||||
Dino2Block(
|
||||
dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
def __init__(self, dim, dtype, device, operations,
|
||||
patch_size: int = 14, image_size: int = 518,
|
||||
use_mask_token: bool = True,
|
||||
num_camera_tokens: int = 0):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
if use_mask_token:
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.mask_token = None
|
||||
if num_camera_tokens > 0:
|
||||
# DA3 stores (ref_token, src_token) pairs that get injected at the
|
||||
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
|
||||
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# scale_factor is (height_scale, width_scale) -- height MUST come first;
|
||||
# swapping these only happens to work for square inputs and breaks
|
||||
# non-square paths like DA3-Small / DA3-Base multi-view.
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
assert (h0, w0) == patch_pos.shape[-2:], (
|
||||
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
|
||||
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
|
||||
f"check scale_factor axis order and +0.1 rounding workaround"
|
||||
)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
patch_size = config_dict.get("patch_size", 14)
|
||||
image_size = config_dict.get("image_size", 518)
|
||||
use_mask_token = config_dict.get("use_mask_token", True)
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
# DA3 extensions (all default to disabled).
|
||||
self.alt_start = config_dict.get("alt_start", -1)
|
||||
self.qknorm_start = config_dict.get("qknorm_start", -1)
|
||||
self.rope_start = config_dict.get("rope_start", -1)
|
||||
self.cat_token = config_dict.get("cat_token", False)
|
||||
rope_freq = config_dict.get("rope_freq", 100.0)
|
||||
|
||||
self.embed_dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
|
||||
if self.rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self._position_getter = _PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self._position_getter = None
|
||||
|
||||
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
|
||||
num_cam_tokens = 2 if self.alt_start != -1 else 0
|
||||
|
||||
self.embeddings = Dino2Embeddings(
|
||||
dim, dtype, device, operations,
|
||||
patch_size=patch_size, image_size=image_size,
|
||||
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
|
||||
)
|
||||
self.encoder = Dino2Encoder(
|
||||
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qknorm_start=self.qknorm_start,
|
||||
)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
if self.alt_start != -1:
|
||||
raise RuntimeError(
|
||||
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
|
||||
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
|
||||
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
|
||||
)
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction."""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module):
|
||||
if i >= max_idx:
|
||||
break
|
||||
return [cache[i] for i in resolved]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Depth Anything 3 forward
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_rope_positions(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
pos = self._position_getter(B * S, ph, pw, device=device)
|
||||
# Shift so the cls/cam token at position 0 is reserved for "no diff".
|
||||
pos = pos + 1
|
||||
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
# Per-view local: real grid positions for patches, 0 for cls token.
|
||||
pos_local = torch.cat([cls_pos, pos], dim=1)
|
||||
# Global (across views): same grid positions; cls token still at 0,
|
||||
# but patches share the same positions in every view.
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3."""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
|
||||
B, S, _, H, W = pixel_values.shape
|
||||
|
||||
# Patch + cls + (interpolated) pos embed for each view.
|
||||
x = pixel_values.reshape(B * S, 3, H, W)
|
||||
x = self.embeddings(x) # (B*S, 1+N, C)
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# optimized_attention is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
out_set = set(out_layers)
|
||||
export_set = set(export_feat_layers) if export_feat_layers else set()
|
||||
outputs: list[torch.Tensor] = []
|
||||
aux_outputs: list[torch.Tensor] = []
|
||||
local_x = x
|
||||
b_idx = None
|
||||
|
||||
|
||||
for i, blk in enumerate(self.encoder.layer):
|
||||
apply_rope = self.rope is not None and i >= self.rope_start
|
||||
block_rope = self.rope if apply_rope else None
|
||||
l_pos = pos_local if apply_rope else None
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
# Reference-view selection threshold: matches the upstream constant
|
||||
# THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
|
||||
# cam_token is provided (camera info already pins the geometry).
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
b_idx = select_reference_view(x, strategy=ref_view_strategy)
|
||||
x = reorder_by_reference(x, b_idx)
|
||||
local_x = reorder_by_reference(local_x, b_idx)
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
x = self._inject_camera_token(x, B, S, cam_token)
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
# Global attention across views: flatten S into the seq dim.
|
||||
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
|
||||
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
else:
|
||||
# Per-view local attention.
|
||||
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
|
||||
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
local_x = x
|
||||
|
||||
if i in out_set:
|
||||
if self.cat_token:
|
||||
out_x = torch.cat([local_x, x], dim=-1)
|
||||
else:
|
||||
out_x = x
|
||||
# Restore original view order on the way out so heads see views
|
||||
# in the user's expected order.
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
out_x = restore_original_order(out_x, b_idx)
|
||||
outputs.append(out_x)
|
||||
|
||||
if i in export_set:
|
||||
aux = x
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
aux = restore_original_order(aux, b_idx)
|
||||
aux_outputs.append(aux)
|
||||
|
||||
# Apply final norm. When cat_token is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
cls_tokens: list[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
cls_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.layernorm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., :self.embed_dim]
|
||||
right = self.layernorm(out_x[..., self.embed_dim:])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
|
||||
# Drop cls/cam token from the patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
|
||||
# Final layernorm + drop cls token from auxiliary features too.
|
||||
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
|
||||
for o in aux_outputs]
|
||||
return list(zip(normed, cls_tokens)), aux_normed
|
||||
|
||||
321
comfy/ldm/boogu/model.py
Normal file
321
comfy/ldm/boogu/model.py
Normal file
@ -0,0 +1,321 @@
|
||||
# Boogu-Image-0.1 transformer
|
||||
# Architecture is an OmniGen2 derivative (see comfy/ldm/omnigen/omnigen2.py) with an
|
||||
# added dual-stream ("double_stream") stage before the single-stream layers, conditioned
|
||||
# by a Qwen3-VL multimodal LLM. Reuses the OmniGen2/Lumina building blocks and the Flux
|
||||
# RoPE core, the only new component is the double-stream block + the hybrid forward order.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.omnigen.omnigen2 import (
|
||||
OmniGen2RotaryPosEmbed,
|
||||
Lumina2CombinedTimestepCaptionEmbedding,
|
||||
LuminaRMSNormZero,
|
||||
LuminaLayerNormContinuous,
|
||||
LuminaFeedForward,
|
||||
Attention,
|
||||
OmniGen2TransformerBlock,
|
||||
apply_rotary_emb,
|
||||
)
|
||||
|
||||
class BooguDoubleStreamProcessor(nn.Module):
|
||||
# Joint attention over [instruct ; img] with separate per-stream q/k/v and output projections.
|
||||
def __init__(self, dim, head_dim, heads, kv_heads, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
query_dim = head_dim * heads
|
||||
kv_dim = head_dim * kv_heads
|
||||
|
||||
self.img_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
|
||||
self.img_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
|
||||
self.img_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.instruct_to_q = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
|
||||
self.instruct_to_k = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
|
||||
self.instruct_to_v = operations.Linear(query_dim, kv_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
self.instruct_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
|
||||
self.img_out = operations.Linear(query_dim, query_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, attn, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
|
||||
batch_size = img_hidden_states.shape[0]
|
||||
L_instruct = instruct_hidden_states.shape[1]
|
||||
|
||||
img_q = self.img_to_q(img_hidden_states)
|
||||
img_k = self.img_to_k(img_hidden_states)
|
||||
img_v = self.img_to_v(img_hidden_states)
|
||||
|
||||
instruct_q = self.instruct_to_q(instruct_hidden_states)
|
||||
instruct_k = self.instruct_to_k(instruct_hidden_states)
|
||||
instruct_v = self.instruct_to_v(instruct_hidden_states)
|
||||
|
||||
# Concatenate instruction first, then image (matches reference processor order).
|
||||
query = torch.cat([instruct_q, img_q], dim=1)
|
||||
key = torch.cat([instruct_k, img_k], dim=1)
|
||||
value = torch.cat([instruct_v, img_v], dim=1)
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, attn.dim_head)
|
||||
key = key.view(batch_size, -1, attn.kv_heads, attn.dim_head)
|
||||
value = value.view(batch_size, -1, attn.kv_heads, attn.dim_head)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, rotary_emb)
|
||||
key = apply_rotary_emb(key, rotary_emb)
|
||||
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
if attn.kv_heads < attn.heads:
|
||||
key = key.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
value = value.repeat_interleave(attn.heads // attn.kv_heads, dim=1)
|
||||
|
||||
hidden_states = optimized_attention_masked(query, key, value, attn.heads, attention_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
|
||||
# Split back to instruction/image, apply per-stream output projections, recombine.
|
||||
instruct_hidden_states = self.instruct_out(hidden_states[:, :L_instruct])
|
||||
img_hidden_states = self.img_out(hidden_states[:, L_instruct:])
|
||||
hidden_states = torch.cat([instruct_hidden_states, img_hidden_states], dim=1)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class BooguJointAttention(nn.Module):
|
||||
# Holds the shared q/k RMSNorm + final output projection
|
||||
def __init__(self, dim, head_dim, heads, kv_heads, eps=1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kv_heads = kv_heads
|
||||
self.dim_head = head_dim
|
||||
self.scale = head_dim ** -0.5
|
||||
|
||||
self.norm_q = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(head_dim, eps=eps, dtype=dtype, device=device)
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(heads * head_dim, dim, bias=False, dtype=dtype, device=device),
|
||||
nn.Dropout(0.0),
|
||||
)
|
||||
self.processor = BooguDoubleStreamProcessor(dim, head_dim, heads, kv_heads, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask=None, transformer_options={}):
|
||||
return self.processor(self, img_hidden_states, instruct_hidden_states, rotary_emb, attention_mask, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class BooguDoubleStreamBlock(nn.Module):
|
||||
# Dual-stream block: joint attention over [instruct ; img] + image self-attention, each stream with its own modulation/MLP.
|
||||
def __init__(self, dim, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
head_dim = dim // num_attention_heads
|
||||
|
||||
self.img_instruct_attn = BooguJointAttention(dim, head_dim, num_attention_heads, num_kv_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
|
||||
self.img_self_attn = Attention(
|
||||
query_dim=dim, dim_head=head_dim, heads=num_attention_heads, kv_heads=num_kv_heads,
|
||||
eps=1e-5, bias=False, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.img_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations)
|
||||
self.instruct_feed_forward = LuminaFeedForward(dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
self.img_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
self.img_norm3 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
self.instruct_norm1 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
self.instruct_norm2 = LuminaRMSNormZero(embedding_dim=dim, norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.img_self_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.img_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.img_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
self.instruct_attn_norm = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.instruct_ffn_norm1 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
self.instruct_ffn_norm2 = operations.RMSNorm(dim, eps=norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, img_hidden_states, instruct_hidden_states, joint_rotary_emb, img_rotary_emb, temb, joint_attention_mask=None, img_attention_mask=None, transformer_options={}):
|
||||
L_instruct = instruct_hidden_states.shape[1]
|
||||
|
||||
img_norm1_out, img_gate_msa, img_scale_mlp, img_gate_mlp = self.img_norm1(img_hidden_states, temb)
|
||||
img_norm2_out, img_shift_mlp, _, _ = self.img_norm2(img_hidden_states, temb)
|
||||
img_norm3_out, img_gate_self, _, _ = self.img_norm3(img_hidden_states, temb)
|
||||
|
||||
instruct_norm1_out, instruct_gate_msa, instruct_scale_mlp, instruct_gate_mlp = self.instruct_norm1(instruct_hidden_states, temb)
|
||||
instruct_norm2_out, instruct_shift_mlp, _, _ = self.instruct_norm2(instruct_hidden_states, temb)
|
||||
|
||||
joint_attn_out = self.img_instruct_attn(img_norm1_out, instruct_norm1_out, joint_rotary_emb, joint_attention_mask, transformer_options=transformer_options)
|
||||
instruct_attn_out = joint_attn_out[:, :L_instruct]
|
||||
img_attn_out = joint_attn_out[:, L_instruct:]
|
||||
|
||||
img_self_attn_out = self.img_self_attn(img_norm3_out, img_norm3_out, img_attention_mask, img_rotary_emb, transformer_options=transformer_options)
|
||||
|
||||
img_hidden_states = img_hidden_states + img_gate_msa.unsqueeze(1).tanh() * self.img_attn_norm(img_attn_out)
|
||||
img_hidden_states = img_hidden_states + img_gate_self.unsqueeze(1).tanh() * self.img_self_attn_norm(img_self_attn_out)
|
||||
img_mlp_input = (1 + img_scale_mlp.unsqueeze(1)) * img_norm2_out + img_shift_mlp.unsqueeze(1)
|
||||
img_mlp_out = self.img_feed_forward(self.img_ffn_norm1(img_mlp_input))
|
||||
img_hidden_states = img_hidden_states + img_gate_mlp.unsqueeze(1).tanh() * self.img_ffn_norm2(img_mlp_out)
|
||||
|
||||
instruct_hidden_states = instruct_hidden_states + instruct_gate_msa.unsqueeze(1).tanh() * self.instruct_attn_norm(instruct_attn_out)
|
||||
instruct_mlp_input = (1 + instruct_scale_mlp.unsqueeze(1)) * instruct_norm2_out + instruct_shift_mlp.unsqueeze(1)
|
||||
instruct_mlp_out = self.instruct_feed_forward(self.instruct_ffn_norm1(instruct_mlp_input))
|
||||
instruct_hidden_states = instruct_hidden_states + instruct_gate_mlp.unsqueeze(1).tanh() * self.instruct_ffn_norm2(instruct_mlp_out)
|
||||
|
||||
return img_hidden_states, instruct_hidden_states
|
||||
|
||||
|
||||
class BooguTransformer2DModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: Optional[int] = None,
|
||||
hidden_size: int = 3360,
|
||||
num_layers: int = 32,
|
||||
num_double_stream_layers: int = 8,
|
||||
num_refiner_layers: int = 2,
|
||||
num_attention_heads: int = 28,
|
||||
num_kv_heads: int = 7,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
norm_eps: float = 1e-5,
|
||||
axes_dim_rope: Tuple[int, int, int] = (40, 40, 40),
|
||||
axes_lens: Tuple[int, int, int] = (2048, 1664, 1664),
|
||||
instruction_feat_dim: int = 4096,
|
||||
timestep_scale: float = 1000.0,
|
||||
image_model=None,
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = out_channels or in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
|
||||
self.rope_embedder = OmniGen2RotaryPosEmbed(
|
||||
theta=10000,
|
||||
axes_dim=axes_dim_rope,
|
||||
axes_lens=axes_lens,
|
||||
patch_size=patch_size,
|
||||
)
|
||||
|
||||
self.x_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
self.ref_image_patch_embedder = operations.Linear(patch_size * patch_size * in_channels, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
|
||||
hidden_size=hidden_size,
|
||||
text_feat_dim=instruction_feat_dim,
|
||||
norm_eps=norm_eps,
|
||||
timestep_scale=timestep_scale, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.noise_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.ref_image_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.context_refiner = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_refiner_layers)
|
||||
])
|
||||
|
||||
self.double_stream_layers = nn.ModuleList([
|
||||
BooguDoubleStreamBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_double_stream_layers)
|
||||
])
|
||||
|
||||
self.single_stream_layers = nn.ModuleList([
|
||||
OmniGen2TransformerBlock(hidden_size, num_attention_heads, num_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, modulation=True, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = LuminaLayerNormContinuous(
|
||||
embedding_dim=hidden_size,
|
||||
conditioning_embedding_dim=min(hidden_size, 1024),
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
out_dim=patch_size * patch_size * self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.image_index_embedding = nn.Parameter(torch.empty(5, hidden_size, device=device, dtype=dtype))
|
||||
|
||||
# Patchify/refine helpers are identical to OmniGen2; reuse via bound methods.
|
||||
flat_and_pad_to_seq = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.flat_and_pad_to_seq
|
||||
img_patch_embed_and_refine = comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel.img_patch_embed_and_refine
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, ref_latents=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
_, _, H_padded, W_padded = hidden_states.shape
|
||||
timestep = 1.0 - timesteps
|
||||
text_hidden_states = context
|
||||
text_attention_mask = attention_mask
|
||||
ref_image_hidden_states = ref_latents
|
||||
device = hidden_states.device
|
||||
|
||||
temb, text_hidden_states = self.time_caption_embed(timestep, text_hidden_states, hidden_states[0].dtype)
|
||||
|
||||
(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes,
|
||||
) = self.flat_and_pad_to_seq(hidden_states, ref_image_hidden_states)
|
||||
|
||||
(
|
||||
context_rotary_emb, ref_img_rotary_emb, noise_rotary_emb,
|
||||
rotary_emb, encoder_seq_lengths, seq_lengths,
|
||||
) = self.rope_embedder(
|
||||
hidden_states.shape[0], text_hidden_states.shape[1], [num_tokens] * text_hidden_states.shape[0],
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
ref_img_sizes, img_sizes, device,
|
||||
)
|
||||
|
||||
for layer in self.context_refiner:
|
||||
text_hidden_states = layer(text_hidden_states, text_attention_mask, context_rotary_emb, transformer_options=transformer_options)
|
||||
|
||||
img_len = hidden_states.shape[1]
|
||||
combined_img_hidden_states = self.img_patch_embed_and_refine(
|
||||
hidden_states, ref_image_hidden_states,
|
||||
img_mask, ref_img_mask,
|
||||
noise_rotary_emb, ref_img_rotary_emb,
|
||||
l_effective_ref_img_len, l_effective_img_len,
|
||||
temb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
# Double-stream stage: the image self-attention only sees the [ref ; noise] tokens,
|
||||
# which sit after the instruction tokens in the joint rope.
|
||||
L_instruct = text_hidden_states.shape[1]
|
||||
combined_img_rotary_emb = rotary_emb[:, L_instruct:]
|
||||
for layer in self.double_stream_layers:
|
||||
combined_img_hidden_states, text_hidden_states = layer(
|
||||
combined_img_hidden_states, text_hidden_states,
|
||||
rotary_emb, combined_img_rotary_emb, temb,
|
||||
joint_attention_mask=None, img_attention_mask=None,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([text_hidden_states, combined_img_hidden_states], dim=1)
|
||||
|
||||
for layer in self.single_stream_layers:
|
||||
hidden_states = layer(hidden_states, None, rotary_emb, temb, transformer_options=transformer_options)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
|
||||
p = self.patch_size
|
||||
output = rearrange(hidden_states[:, -img_len:], 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h=H_padded // p, w=W_padded // p, p1=p, p2=p)[:, :, :H, :W]
|
||||
|
||||
return -output
|
||||
25
comfy/ldm/colormap.py
Normal file
25
comfy/ldm/colormap.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Colormap utilities for depth and geometry visualisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
|
||||
|
||||
Args:
|
||||
x: Float tensor with values in [0, 1].
|
||||
|
||||
Returns:
|
||||
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
|
||||
"""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
@ -515,7 +515,7 @@ class Block(nn.Module):
|
||||
h=H,
|
||||
w=W,
|
||||
)
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_self_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
|
||||
def _x_fn(
|
||||
_x_B_T_H_W_D: torch.Tensor,
|
||||
@ -548,7 +548,7 @@ class Block(nn.Module):
|
||||
shift_cross_attn_B_T_1_1_D,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_cross_attn_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
|
||||
normalized_x_B_T_H_W_D = _fn(
|
||||
x_B_T_H_W_D,
|
||||
@ -557,7 +557,7 @@ class Block(nn.Module):
|
||||
shift_mlp_B_T_1_1_D,
|
||||
)
|
||||
result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype))
|
||||
x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype)
|
||||
x_B_T_H_W_D = torch.addcmul(x_B_T_H_W_D, gate_mlp_B_T_1_1_D.to(residual_dtype), result_B_T_H_W_D.to(residual_dtype))
|
||||
return x_B_T_H_W_D
|
||||
|
||||
|
||||
|
||||
177
comfy/ldm/depth_anything_3/camera.py
Normal file
177
comfy/ldm/depth_anything_3/camera.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Camera-token encoder and decoder for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Mlp(nn.Module):
|
||||
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _LayerScale(nn.Module):
|
||||
"""Per-channel learnable scaling. Matches upstream LayerScale."""
|
||||
|
||||
def __init__(self, dim, *, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gamma.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
""" Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention;
|
||||
Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias})."""
|
||||
|
||||
def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
q, k, v = qkv.unbind(2) # each (B, N, C)
|
||||
attn_fn = optimized_attention_for_device(x.device, small_input=True)
|
||||
out = attn_fn(q, k, v, heads=self.num_heads)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block."""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.ls1(self.attn(self.norm1(x)))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CameraEnc(nn.Module):
|
||||
"""Encode per-view (extrinsics, intrinsics) into a camera token.
|
||||
|
||||
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
|
||||
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
|
||||
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
|
||||
of the DINOv2 backbone in place of the cls token.
|
||||
|
||||
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_out: int = 1024,
|
||||
dim_in: int = 9,
|
||||
trunk_depth: int = 4,
|
||||
target_dim: int = 9,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4,
|
||||
init_values: float = 0.01,
|
||||
*,
|
||||
device=None, dtype=None, operations=None,
|
||||
**_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.target_dim = target_dim
|
||||
self.trunk_depth = trunk_depth
|
||||
self.trunk = nn.Sequential(*[
|
||||
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(trunk_depth)
|
||||
])
|
||||
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.pose_branch = _Mlp(
|
||||
in_features=dim_in,
|
||||
hidden_features=dim_out // 2,
|
||||
out_features=dim_out,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
||||
image_size_hw) -> torch.Tensor:
|
||||
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
|
||||
c2ws = affine_inverse(extrinsics)
|
||||
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
|
||||
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
|
||||
tokens = self.token_norm(tokens)
|
||||
tokens = self.trunk(tokens)
|
||||
tokens = self.trunk_norm(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
class CameraDec(nn.Module):
|
||||
"""Decode the final cam token into a 9-D pose encoding.
|
||||
|
||||
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
|
||||
always predicted by the network; the quaternion and FoV can either be
|
||||
predicted or supplied via ``camera_encoding`` (used at training time
|
||||
when GT cameras are available -- not exercised at inference here).
|
||||
|
||||
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int = 1536,
|
||||
*, device=None, dtype=None, operations=None, **_kwargs):
|
||||
super().__init__()
|
||||
d = dim_in
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
|
||||
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
|
||||
self.fc_fov = nn.Sequential(
|
||||
operations.Linear(d, 2, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, feat: torch.Tensor,
|
||||
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
|
||||
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
|
||||
B, N = feat.shape[:2]
|
||||
feat = feat.reshape(B * N, -1)
|
||||
feat = self.backbone(feat)
|
||||
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
||||
if camera_encoding is None:
|
||||
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
||||
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
||||
else:
|
||||
out_qvec = camera_encoding[..., 3:7]
|
||||
out_fov = camera_encoding[..., -2:]
|
||||
return torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
||||
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
@ -0,0 +1,489 @@
|
||||
"""DPT / DualDPT heads for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(*self.dims)
|
||||
|
||||
|
||||
def _custom_interpolate(
|
||||
x: torch.Tensor,
|
||||
size: Optional[Tuple[int, int]] = None,
|
||||
scale_factor: Optional[float] = None,
|
||||
mode: str = "bilinear",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if size is None:
|
||||
assert scale_factor is not None
|
||||
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
||||
INT_MAX = 1610612736
|
||||
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
||||
if total > INT_MAX:
|
||||
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
||||
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
|
||||
return torch.cat(outs, dim=0).contiguous()
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
span_y = 1.0 / diag_factor
|
||||
left_x = -span_x * (width - 1) / width
|
||||
right_x = span_x * (width - 1) / width
|
||||
top_y = -span_y * (height - 1) / height
|
||||
bottom_y = span_y * (height - 1) / height
|
||||
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
||||
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
||||
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
||||
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
|
||||
|
||||
|
||||
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
|
||||
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
|
||||
pos = pos.reshape(-1)
|
||||
out = torch.einsum("m,d->md", pos, omega)
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
|
||||
emb = torch.cat([emb_x, emb_y], dim=-1)
|
||||
return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
||||
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
|
||||
pw, ph = x.shape[-1], x.shape[-2]
|
||||
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
||||
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
|
||||
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
|
||||
return x + pe
|
||||
|
||||
|
||||
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
act = (activation or "linear").lower()
|
||||
if act == "exp":
|
||||
return torch.exp(x)
|
||||
if act == "expp1":
|
||||
return torch.exp(x) + 1
|
||||
if act == "expm1":
|
||||
return torch.expm1(x)
|
||||
if act == "relu":
|
||||
return torch.relu(x)
|
||||
if act == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
if act == "softplus":
|
||||
return F.softplus(x)
|
||||
if act == "tanh":
|
||||
return torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fusion building blocks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
if has_residual:
|
||||
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
||||
y = y + self.resConfUnit1(xs[1])
|
||||
y = self.resConfUnit2(y)
|
||||
if size is None:
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DPT(nn.Module):
|
||||
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 1,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = False,
|
||||
down_ratio: int = 1,
|
||||
head_name: str = "depth",
|
||||
use_sky_head: bool = True,
|
||||
sky_name: str = "sky",
|
||||
sky_activation: str = "relu",
|
||||
norm_type: str = "idt",
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.head_main = head_name
|
||||
self.sky_name = sky_name
|
||||
self.out_dim = output_dim
|
||||
self.has_conf = output_dim > 1
|
||||
self.use_sky_head = use_sky_head
|
||||
self.sky_activation = sky_activation
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
if norm_type == "layer":
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
||||
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
||||
out = self.scratch.refinenet1(out, l1_rn)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
fused = self.scratch.output_conv1(out)
|
||||
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
fused = _add_pos_embed(fused, W, H)
|
||||
feat = fused
|
||||
|
||||
main_logits = self.scratch.output_conv2(feat)
|
||||
outs = {}
|
||||
if self.has_conf:
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
|
||||
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
|
||||
else:
|
||||
pred = _apply_activation(main_logits, self.activation)
|
||||
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
|
||||
|
||||
if self.use_sky_head:
|
||||
sky_logits = self.scratch.sky_output_conv2(feat)
|
||||
if self.sky_activation.lower() == "sigmoid":
|
||||
sky = torch.sigmoid(sky_logits)
|
||||
elif self.sky_activation.lower() == "relu":
|
||||
sky = F.relu(sky_logits)
|
||||
else:
|
||||
sky = sky_logits
|
||||
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 2,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = True,
|
||||
down_ratio: int = 1,
|
||||
aux_pyramid_levels: int = 4,
|
||||
aux_out1_conv_num: int = 5,
|
||||
head_names: Tuple[str, str] = ("depth", "ray"),
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.aux_levels = aux_pyramid_levels
|
||||
self.aux_out1_conv_num = aux_out1_conv_num
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
|
||||
# DepthAnything3Net flips this on when running multi-view + ray-pose.
|
||||
self.enable_aux: bool = False
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
# Main head neck + final projection
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
self.scratch.output_conv1_aux = nn.ModuleList([
|
||||
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
# Aux final projection per level (includes LayerNorm permute path).
|
||||
ln_seq = [Permute((0, 2, 3, 1)),
|
||||
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
|
||||
# aux_out1_conv_num=5 in all Apache-2.0 variants.
|
||||
return nn.Sequential(
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
|
||||
# before interpolation -- replicate that order here).
|
||||
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
|
||||
aux_pyr = [a4]
|
||||
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
|
||||
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
|
||||
m = self.scratch.refinenet1(m, l1_rn)
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
|
||||
m = self.scratch.output_conv1(m)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
m = _add_pos_embed(m, W, H)
|
||||
main_logits = self.scratch.output_conv2(m)
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
|
||||
outs = {
|
||||
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
|
||||
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
|
||||
}
|
||||
|
||||
if self.enable_aux:
|
||||
# Auxiliary "ray" head (multi-level inside) -- only the last level
|
||||
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
|
||||
# each aux pyramid level goes through ``output_conv1_aux[i]``
|
||||
# (5-layer conv stack that ends at ``features // 2`` channels),
|
||||
# then the last level optionally gets a pos-embed and finally
|
||||
# ``output_conv2_aux[-1]``.
|
||||
aux_processed = [
|
||||
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
|
||||
]
|
||||
last_aux = aux_processed[-1]
|
||||
if self.pos_embed:
|
||||
last_aux = _add_pos_embed(last_aux, W, H)
|
||||
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
|
||||
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
|
||||
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
|
||||
aux_pred = fmap_last[..., :-1]
|
||||
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
|
||||
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
|
||||
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
|
||||
|
||||
return outs
|
||||
236
comfy/ldm/depth_anything_3/model.py
Normal file
236
comfy/ldm/depth_anything_3/model.py
Normal file
@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .camera import CameraDec, CameraEnc
|
||||
from .dpt import DPT, DualDPT
|
||||
from .ray_pose import get_extrinsic_from_camray
|
||||
from .transform import affine_inverse, pose_encoding_to_extri_intri
|
||||
|
||||
|
||||
_HEAD_REGISTRY = {
|
||||
"dpt": DPT,
|
||||
"dualdpt": DualDPT,
|
||||
}
|
||||
|
||||
|
||||
# Backbone presets (mirror the upstream DINOv2 ViT variants).
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
|
||||
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
|
||||
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
|
||||
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_backbone_config(
|
||||
backbone_name: str,
|
||||
*,
|
||||
alt_start: int,
|
||||
qknorm_start: int,
|
||||
rope_start: int,
|
||||
cat_token: bool,
|
||||
) -> dict:
|
||||
if backbone_name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
|
||||
cfg = dict(_BACKBONE_PRESETS[backbone_name])
|
||||
cfg.update(dict(
|
||||
layer_norm_eps=1e-6,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
# No mask_token in DA3 weights; omit param to avoid load warnings.
|
||||
use_mask_token=False,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
rope_freq=100.0,
|
||||
))
|
||||
return cfg
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# --- Backbone ---
|
||||
backbone_name: str = "vitl",
|
||||
out_layers: Sequence[int] = (4, 11, 17, 23),
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # dpt or dualdpt
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# --- Camera (multi-view) ---
|
||||
has_cam_enc: bool = False,
|
||||
has_cam_dec: bool = False,
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
**_ignored,
|
||||
):
|
||||
super().__init__()
|
||||
head_cls = _HEAD_REGISTRY[head_type.lower()]
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
self.out_layers = list(out_layers)
|
||||
|
||||
backbone_cfg = _build_backbone_config(
|
||||
backbone_name,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
)
|
||||
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
patch_size=self.PATCH_SIZE,
|
||||
output_dim=head_output_dim,
|
||||
features=head_features,
|
||||
out_channels=tuple(head_out_channels),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
if self.head_type == "dpt":
|
||||
head_kwargs.update(
|
||||
use_sky_head=head_use_sky_head,
|
||||
pos_embed=(False if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
else: # dualdpt
|
||||
head_kwargs.update(
|
||||
pos_embed=(True if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
self.head = head_cls(**head_kwargs)
|
||||
|
||||
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
|
||||
embed_dim = backbone_cfg["hidden_size"]
|
||||
if has_cam_enc:
|
||||
self.cam_enc = CameraEnc(
|
||||
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
|
||||
num_heads=max(1, embed_dim // 64),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_enc = None
|
||||
if has_cam_dec:
|
||||
default_dim = embed_dim * (2 if cat_token else 1)
|
||||
self.cam_dec = CameraDec(
|
||||
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_dec = None
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
extrinsics: Optional[torch.Tensor] = None,
|
||||
intrinsics: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
use_ray_pose: bool = False,
|
||||
ref_view_strategy: str = "saddle_balanced",
|
||||
export_feat_layers: Optional[Sequence[int]] = None,
|
||||
**_unused,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run depth and optionally pose prediction."""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
|
||||
|
||||
B, S, _, H, W = image.shape
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
# Camera-token preparation (multi-view path).
|
||||
cam_token = None
|
||||
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
|
||||
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
|
||||
|
||||
# Toggle aux ray output on/off depending on what the caller asked for.
|
||||
if isinstance(self.head, DualDPT):
|
||||
self.head.enable_aux = bool(use_ray_pose)
|
||||
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
|
||||
image, self.out_layers, cam_token=cam_token,
|
||||
ref_view_strategy=ref_view_strategy,
|
||||
export_feat_layers=export_feat_layers,
|
||||
)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Pose prediction.
|
||||
out: Dict[str, torch.Tensor] = {}
|
||||
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
|
||||
ray = head_out["ray"]
|
||||
ray_conf = head_out["ray_conf"]
|
||||
extr_c2w, focal, pp = get_extrinsic_from_camray(
|
||||
ray, ray_conf, ray.shape[-3], ray.shape[-2],
|
||||
)
|
||||
# Match the upstream output: w2c, drop the homogeneous row.
|
||||
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
|
||||
# Build pixel-space intrinsics from the normalised focal/pp output.
|
||||
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
|
||||
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
|
||||
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
|
||||
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
|
||||
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
|
||||
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
|
||||
out["extrinsics"] = extr_w2c
|
||||
out["intrinsics"] = intr
|
||||
elif self.cam_dec is not None and S > 1:
|
||||
# Decode the cam-token of the final out_layer into a pose encoding.
|
||||
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
|
||||
pose_enc = self.cam_dec(cam_feat)
|
||||
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
|
||||
# Match the upstream output convention: w2c (world->camera), 3x4.
|
||||
c2w_4x4 = torch.cat([
|
||||
c2w_3x4,
|
||||
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
|
||||
.view(1, 1, 1, 4).expand(B, S, 1, 4),
|
||||
], dim=-2)
|
||||
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
|
||||
out["intrinsics"] = intr
|
||||
|
||||
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
|
||||
# per-image consumer keeps its (B*S, H, W) interface.
|
||||
for k, v in head_out.items():
|
||||
if k in ("ray", "ray_conf"):
|
||||
# Keep multi-view shape for downstream pose work.
|
||||
out[k] = v
|
||||
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
|
||||
out[k] = v.reshape(B * S, *v.shape[2:])
|
||||
else:
|
||||
out[k] = v
|
||||
|
||||
if export_feat_layers:
|
||||
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
|
||||
return out
|
||||
|
||||
def _reshape_aux_features(self, aux_feats, H: int, W: int):
|
||||
"""Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
|
||||
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
|
||||
out = []
|
||||
for f in aux_feats:
|
||||
B, S, N, C = f.shape
|
||||
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
|
||||
out.append(f.reshape(B, S, ph, pw, C))
|
||||
return out
|
||||
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Input/output preprocessing helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
# ImageNet normalization constants used during DA3 training.
|
||||
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
|
||||
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
down = (x // patch) * patch
|
||||
up = down + patch
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method).
|
||||
lower_bound_resize: scale shortest side to process_res, then round."""
|
||||
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
elif method == "lower_bound_resize":
|
||||
shortest = min(orig_h, orig_w)
|
||||
scale = process_res / float(shortest)
|
||||
else:
|
||||
raise ValueError(f"Unsupported process_res_method: {method}")
|
||||
|
||||
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
|
||||
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
# (B, H, W, 3) -> (B, 3, H, W)
|
||||
x = image.movedim(-1, 1).contiguous()
|
||||
if (target_h, target_w) != (H, W):
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
|
||||
# Lanczos in ``common_upscale`` is anti-aliased and produces the
|
||||
# closest pixel-wise match in a sweep across {bilinear, bicubic,
|
||||
# area, lanczos, bislerp}. Used in both directions for simplicity.
|
||||
x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Output post-processing (sky-aware clipping for Mono/Metric variants)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
|
||||
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
|
||||
"""Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor."""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
|
||||
non_sky_depth = depth[non_sky]
|
||||
if non_sky_depth.numel() > 100_000:
|
||||
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
|
||||
sampled = non_sky_depth[idx]
|
||||
else:
|
||||
sampled = non_sky_depth
|
||||
|
||||
max_depth = torch.quantile(sampled, quantile)
|
||||
out = depth.clone()
|
||||
out[~non_sky] = max_depth
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
|
||||
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
valid = depth[mask]
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
|
||||
if valid.numel() > 100_000:
|
||||
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
|
||||
sample = valid[idx]
|
||||
else:
|
||||
sample = valid
|
||||
|
||||
lo = torch.quantile(sample, low_quantile)
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# Nearer pixels are brighter (1.0)
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown)
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
|
||||
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Simple per-frame min/max normalization with near=1.0 convention."""
|
||||
lo = depth.amin(dim=(-2, -1), keepdim=True)
|
||||
hi = depth.amax(dim=(-2, -1), keepdim=True)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
|
||||
|
||||
|
||||
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Decompose A = Q @ L with Q orthogonal and L lower-triangular.
|
||||
Implemented in terms of QR by reversing the columns/rows; the standard
|
||||
trick from the upstream reference. Inputs A are (3, 3)."""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
|
||||
A_tilde = A @ P
|
||||
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
||||
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
||||
Q_tilde = Q_tilde.to(A.dtype)
|
||||
R_tilde = R_tilde.to(A.dtype)
|
||||
Q = Q_tilde @ P
|
||||
L = P @ R_tilde @ P
|
||||
d = torch.diag(L)
|
||||
sign = torch.sign(d)
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
return Q, L
|
||||
|
||||
|
||||
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Weighted-LSQ + RANSAC homography (batched)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
|
||||
"""Solve a single H with weighted least-squares (DLT)."""
|
||||
N = src_pts.shape[0]
|
||||
if N < 4:
|
||||
raise ValueError("At least 4 points are required to compute a homography.")
|
||||
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
|
||||
x = src_pts[:, 0:1]
|
||||
y = src_pts[:, 1:2]
|
||||
u = dst_pts[:, 0:1]
|
||||
v = dst_pts[:, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
|
||||
A = torch.cat([A1, A2], dim=0) # (2N, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[-1].reshape(3, 3)
|
||||
return H / H[-1, -1]
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3)."""
|
||||
B, K, _ = src_pts_batch.shape
|
||||
w = confident_weight_batch.sqrt().unsqueeze(2)
|
||||
x = src_pts_batch[:, :, 0:1]
|
||||
y = src_pts_batch[:, :, 1:2]
|
||||
u = dst_pts_batch[:, :, 0:1]
|
||||
v = dst_pts_batch[:, :, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
|
||||
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[:, -1].reshape(B, 3, 3)
|
||||
return H / H[:, 2:3, 2:3]
|
||||
|
||||
|
||||
def _ransac_find_homography_weighted_batched(
|
||||
src_pts: torch.Tensor, # (B, N, 2)
|
||||
dst_pts: torch.Tensor, # (B, N, 2)
|
||||
confident_weight: torch.Tensor, # (B, N)
|
||||
n_sample: int,
|
||||
n_iter: int = 100,
|
||||
reproj_threshold: float = 3.0,
|
||||
num_sample_for_ransac: int = 8,
|
||||
max_inlier_num: int = 10000,
|
||||
rand_sample_iters_idx: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices."""
|
||||
B, N, _ = src_pts.shape
|
||||
assert N >= 4
|
||||
device = src_pts.device
|
||||
|
||||
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
|
||||
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
|
||||
|
||||
if rand_sample_iters_idx is None:
|
||||
rand_sample_iters_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
|
||||
for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
|
||||
b_idx = (
|
||||
torch.arange(B, device=device)
|
||||
.view(B, 1, 1)
|
||||
.expand(B, n_iter, num_sample_for_ransac)
|
||||
)
|
||||
src_b = src_pts[b_idx, rand_idx]
|
||||
dst_b = dst_pts[b_idx, rand_idx]
|
||||
w_b = confident_weight[b_idx, rand_idx]
|
||||
|
||||
cB, cN = src_b.shape[:2]
|
||||
H_batch = _find_homography_weighted_lsq_batched(
|
||||
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
|
||||
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
|
||||
|
||||
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
|
||||
proj = torch.bmm(
|
||||
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
|
||||
H_batch.reshape(-1, 3, 3).transpose(1, 2),
|
||||
) # (B*n_iter, N, 3)
|
||||
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
|
||||
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
|
||||
inlier_mask = err < reproj_threshold
|
||||
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
|
||||
|
||||
# Refit with the inlier set (per-batch, since the inlier counts vary).
|
||||
H_inlier_list = []
|
||||
for b in range(B):
|
||||
mask = best_inlier_mask[b]
|
||||
in_src = src_pts[b][mask]
|
||||
in_dst = dst_pts[b][mask]
|
||||
in_w = confident_weight[b][mask]
|
||||
if in_src.shape[0] < 4:
|
||||
# Fall back to identity when RANSAC fails to find enough inliers.
|
||||
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
|
||||
continue
|
||||
sorted_w = torch.argsort(in_w, descending=True)
|
||||
if len(sorted_w) > max_inlier_num:
|
||||
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
|
||||
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
|
||||
H_inlier_list.append(
|
||||
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
|
||||
)
|
||||
return torch.stack(H_inlier_list, dim=0)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Camera-ray utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane."""
|
||||
dx = 1.0 / num_x
|
||||
dy = 1.0 / num_y
|
||||
# Centered camera-space coords directly (skip the K^-1 step since it's
|
||||
# just a translation by -1 on x and y when K is identity-with-center=1).
|
||||
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
|
||||
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
|
||||
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
||||
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
|
||||
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
|
||||
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
|
||||
|
||||
|
||||
def _camray_to_caminfo(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
|
||||
reproj_threshold: float = 0.2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
|
||||
if confidence is None:
|
||||
confidence = torch.ones_like(camray[..., 0])
|
||||
B, S, h, w, _ = camray.shape
|
||||
device = camray.device
|
||||
dtype = camray.dtype
|
||||
|
||||
rays_target = camray[..., :3] # (B, S, h, w, 3)
|
||||
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
|
||||
|
||||
# Flatten (B*S, h*w, *) for the RANSAC routine.
|
||||
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
|
||||
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
|
||||
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
|
||||
|
||||
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
|
||||
z_thresh = 1e-4
|
||||
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
|
||||
weights = torch.where(mask, weights, torch.zeros_like(weights))
|
||||
src = rays_origin.clone()
|
||||
dst = rays_target.clone()
|
||||
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
|
||||
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
|
||||
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
|
||||
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
|
||||
src = src[..., :2]
|
||||
dst = dst[..., :2]
|
||||
|
||||
N = src.shape[1]
|
||||
n_iter = 100
|
||||
sample_ratio = 0.3
|
||||
num_sample_for_ransac = 8
|
||||
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
|
||||
rand_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Chunk along the view axis to keep peak memory predictable.
|
||||
chunk = 2
|
||||
A_list = []
|
||||
for i in range(0, src.shape[0], chunk):
|
||||
A = _ransac_find_homography_weighted_batched(
|
||||
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
|
||||
n_sample=n_sample, n_iter=n_iter,
|
||||
num_sample_for_ransac=num_sample_for_ransac,
|
||||
reproj_threshold=reproj_threshold,
|
||||
rand_sample_iters_idx=rand_idx,
|
||||
max_inlier_num=8000,
|
||||
)
|
||||
# Flip sign on dets that come out < 0 (so that the QL produces a
|
||||
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
|
||||
# do the comparison in fp32.
|
||||
flip = torch.linalg.det(A.float()) < 0
|
||||
A = torch.where(flip[:, None, None], -A, A)
|
||||
A_list.append(A)
|
||||
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
|
||||
|
||||
R_list, f_list, pp_list = [], [], []
|
||||
for i in range(A.shape[0]):
|
||||
R, L = _ql_decomposition(A[i])
|
||||
L = L / L[2][2]
|
||||
f_list.append(torch.stack((L[0][0], L[1][1])))
|
||||
pp_list.append(torch.stack((L[2][0], L[2][1])))
|
||||
R_list.append(R)
|
||||
R = torch.stack(R_list).reshape(B, S, 3, 3)
|
||||
focal = torch.stack(f_list).reshape(B, S, 2)
|
||||
pp = torch.stack(pp_list).reshape(B, S, 2)
|
||||
|
||||
# Translation: confidence-weighted average of camray direction(s).
|
||||
cf = confidence.flatten(0, 1).flatten(1, 2)
|
||||
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
|
||||
T = T / cf.sum(dim=-1, keepdim=True)
|
||||
T = T.reshape(B, S, 3)
|
||||
|
||||
# Match upstream output convention: focal -> 1/focal, pp + 1.
|
||||
return R, T, 1.0 / focal, pp + 1.0
|
||||
|
||||
|
||||
def get_extrinsic_from_camray(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
patch_size_y: int,
|
||||
patch_size_x: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output."""
|
||||
if conf.ndim == 5 and conf.shape[-1] == 1:
|
||||
conf = conf.squeeze(-1)
|
||||
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
|
||||
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
|
||||
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
|
||||
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
|
||||
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
|
||||
return extr, focal, pp
|
||||
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
|
||||
|
||||
|
||||
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
|
||||
# Reference selection only runs when there are at least this many views.
|
||||
THRESH_FOR_REF_SELECTION: int = 3
|
||||
|
||||
|
||||
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element."""
|
||||
B, S, _, _ = x.shape
|
||||
if S <= 1:
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "first":
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "middle":
|
||||
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
|
||||
|
||||
# Feature-based strategies: normalised cls/cam token per view.
|
||||
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
|
||||
|
||||
if strategy == "saddle_balanced":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
|
||||
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
|
||||
feat_var = img_class_feat.var(dim=-1) # (B,S)
|
||||
|
||||
def _normalize(metric):
|
||||
mn = metric.min(dim=1, keepdim=True).values
|
||||
mx = metric.max(dim=1, keepdim=True).values
|
||||
return (metric - mn) / (mx - mn + 1e-8)
|
||||
|
||||
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
|
||||
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
|
||||
return balance.argmin(dim=1)
|
||||
|
||||
if strategy == "saddle_sim_range":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_max = sim_no_diag.max(dim=-1).values
|
||||
sim_min = sim_no_diag.min(dim=-1).values
|
||||
return (sim_max - sim_min).argmax(dim=1)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reference view selection strategy: {strategy!r}. "
|
||||
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
|
||||
)
|
||||
|
||||
|
||||
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder x so the reference view is at position 0 in axis S."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
reorder = torch.where(
|
||||
(positions > 0) & (positions <= b_idx_exp),
|
||||
positions - 1,
|
||||
positions,
|
||||
)
|
||||
reorder[:, 0] = b_idx
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, reorder]
|
||||
|
||||
|
||||
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of reorder_by_reference."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions)
|
||||
restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp))
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, restore]
|
||||
160
comfy/ldm/depth_anything_3/transform.py
Normal file
160
comfy/ldm/depth_anything_3/transform.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Geometry / camera transform helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Affine 4x4 helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
|
||||
"""Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``."""
|
||||
if ext.shape[-2:] == (4, 4):
|
||||
return ext
|
||||
if ext.shape[-2:] == (3, 4):
|
||||
ones = torch.zeros_like(ext[..., :1, :4])
|
||||
ones[..., 0, 3] = 1.0
|
||||
return torch.cat([ext, ones], dim=-2)
|
||||
raise ValueError(f"Invalid affine shape: {ext.shape}")
|
||||
|
||||
|
||||
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
|
||||
R = A[..., :3, :3]
|
||||
T = A[..., :3, 3:]
|
||||
P = A[..., 3:, :]
|
||||
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quaternion <-> rotation matrix (xyzw / scalar-last)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""sqrt(max(0, x)) with a zero subgradient where x == 0."""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
|
||||
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert quaternions (xyzw) to (...,3,3) rotation matrices."""
|
||||
i, j, k, r = torch.unbind(quaternions, -1)
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
o = torch.stack(
|
||||
(
|
||||
1 - two_s * (j * j + k * k),
|
||||
two_s * (i * j - k * r),
|
||||
two_s * (i * k + j * r),
|
||||
two_s * (i * j + k * r),
|
||||
1 - two_s * (i * i + k * k),
|
||||
two_s * (j * k - i * r),
|
||||
two_s * (i * k - j * r),
|
||||
two_s * (j * k + i * r),
|
||||
1 - two_s * (i * i + j * j),
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert (...,3,3) rotation matrices to quaternions (xyzw)."""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
batch_dim = matrix.shape[:-2]
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||
)
|
||||
|
||||
q_abs = _sqrt_positive_part(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
|
||||
batch_dim + (4,)
|
||||
)
|
||||
# Reorder rijk -> xyzw (i.e. ijkr).
|
||||
out = out[..., [1, 2, 3, 0]]
|
||||
return standardize_quaternion(out)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Pose-encoding <-> extrinsics + intrinsics
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Pack (extr, intr, image_size) into the 9-D pose-encoding vector.
|
||||
extrinsics: camera-to-world (c2w) (B,S,4,4) matrices,
|
||||
intrinsics: pixel-space (B,S,3,3) matrices,
|
||||
image_size_hw: is a (H, W) pair.
|
||||
"""
|
||||
R = extrinsics[..., :3, :3]
|
||||
T = extrinsics[..., :3, 3]
|
||||
quat = mat_to_quat(R)
|
||||
H, W = image_size_hw
|
||||
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
||||
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
||||
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of extri_intri_to_pose_encoding."""
|
||||
T = pose_encoding[..., :3]
|
||||
quat = pose_encoding[..., 3:7]
|
||||
fov_h = pose_encoding[..., 7]
|
||||
fov_w = pose_encoding[..., 8]
|
||||
# Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
|
||||
# quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
|
||||
quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
R = quat_to_mat(quat)
|
||||
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
||||
H, W = image_size_hw
|
||||
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
||||
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics[..., 0, 0] = fx
|
||||
intrinsics[..., 1, 1] = fy
|
||||
intrinsics[..., 0, 2] = W / 2
|
||||
intrinsics[..., 1, 2] = H / 2
|
||||
intrinsics[..., 2, 2] = 1.0
|
||||
return extrinsics, intrinsics
|
||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, dtype):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = emb.to(dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t)
|
||||
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
llm = self.llm_cond_proj(llm) * text_mask
|
||||
h[:, :L_text] = h[:, :L_text] + llm
|
||||
|
||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long))
|
||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
|
||||
|
||||
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
L_text = context_chunk.shape[1]
|
||||
L = L_text + L_img
|
||||
@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
|
||||
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
||||
|
||||
290
comfy/ldm/krea2/model.py
Normal file
290
comfy/ldm/krea2/model.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""Krea 2 (K2) — single-stream MMDiT.
|
||||
|
||||
Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are
|
||||
concatenated into one sequence and run through ``layers`` shared transformer blocks with
|
||||
AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.flux.layers import EmbedND, timestep_embedding
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered)."""
|
||||
|
||||
def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
dtype = x.dtype
|
||||
weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0
|
||||
return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype)
|
||||
|
||||
|
||||
class QKNorm(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
|
||||
self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, q, k):
|
||||
return self.qnorm(q), self.knorm(k)
|
||||
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
mlpdim = int(2 * features / 3) * multiplier
|
||||
mlpdim = multiple * ((mlpdim + multiple - 1) // multiple)
|
||||
self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
|
||||
self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype)
|
||||
self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.down(F.silu(self.gate(x)).mul_(self.up(x)))
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.kvheads = kvheads if kvheads is not None else heads
|
||||
self.headdim = dim // self.heads
|
||||
self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype)
|
||||
self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
|
||||
self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype)
|
||||
self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
|
||||
self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations)
|
||||
self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, freqs=None, mask=None, transformer_options={}):
|
||||
q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x)
|
||||
q = rearrange(q, "B L (H D) -> B H L D", H=self.heads)
|
||||
k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads)
|
||||
v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads)
|
||||
q, k = self.qknorm(q, k)
|
||||
if freqs is not None:
|
||||
q, k = apply_rope(q, k, freqs)
|
||||
if self.kvheads != self.heads:
|
||||
rep = self.heads // self.kvheads
|
||||
k = k.repeat_interleave(rep, dim=1)
|
||||
v = v.repeat_interleave(rep, dim=1)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask=mask, skip_reshape=True,
|
||||
transformer_options=transformer_options)
|
||||
return self.wo(out * F.sigmoid(gate))
|
||||
|
||||
|
||||
class SimpleModulation(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, vec):
|
||||
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0)
|
||||
scale, shift = out.chunk(2, dim=1)
|
||||
return scale, shift
|
||||
|
||||
|
||||
class DoubleSharedModulation(nn.Module):
|
||||
def __init__(self, dim: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, vec):
|
||||
out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device)
|
||||
return out.chunk(6, dim=-1)
|
||||
|
||||
|
||||
class TextFusionBlock(nn.Module):
|
||||
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, mask=None, transformer_options={}):
|
||||
x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.postnorm(x))
|
||||
return x
|
||||
|
||||
|
||||
class TextFusionTransformer(nn.Module):
|
||||
def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.layerwise_blocks = nn.ModuleList([
|
||||
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(2)
|
||||
])
|
||||
self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype)
|
||||
self.refiner_blocks = nn.ModuleList([
|
||||
TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(2)
|
||||
])
|
||||
|
||||
def forward(self, x, mask=None, transformer_options={}):
|
||||
b, l, n, d = x.shape
|
||||
x = x.reshape(b * l, n, d)
|
||||
for block in self.layerwise_blocks:
|
||||
x = block(x.contiguous(), mask=None, transformer_options=transformer_options)
|
||||
x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l)
|
||||
x = self.projector(x).squeeze(-1)
|
||||
for block in self.refiner_blocks:
|
||||
x = block(x, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations)
|
||||
self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations)
|
||||
self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, vec, freqs, mask=None, transformer_options={}):
|
||||
prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec)
|
||||
x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options)
|
||||
x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, features, patch, channels, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations)
|
||||
self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype)
|
||||
self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward(self, x, tvec):
|
||||
scale, shift = self.modulation(tvec)
|
||||
x = (1 + scale) * self.norm(x) + shift
|
||||
return self.linear(x)
|
||||
|
||||
|
||||
class SingleStreamDiT(nn.Module):
|
||||
def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4,
|
||||
layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12,
|
||||
txtheads=20, txtkvheads=20, image_model=None,
|
||||
device=None, dtype=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.patch = patch
|
||||
self.channels = channels
|
||||
self.tdim = tdim
|
||||
self.heads = heads
|
||||
self.txtdim = txtdim
|
||||
self.txtlayers = txtlayers
|
||||
|
||||
headdim = features // heads
|
||||
axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)]
|
||||
assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}"
|
||||
self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes)
|
||||
|
||||
self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype)
|
||||
self.blocks = nn.ModuleList([
|
||||
SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(layers)
|
||||
])
|
||||
self.tmlp = nn.Sequential(
|
||||
operations.Linear(tdim, features, device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features, device=device, dtype=dtype),
|
||||
)
|
||||
self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
self.txtmlp = nn.Sequential(
|
||||
RMSNorm(txtdim, device=device, dtype=dtype, operations=operations),
|
||||
operations.Linear(txtdim, features, device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features, device=device, dtype=dtype),
|
||||
)
|
||||
self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations)
|
||||
self.tproj = nn.Sequential(
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(features, features * 6, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs):
|
||||
temporal = x.ndim == 5
|
||||
if temporal:
|
||||
b5, c5, t5, h5, w5 = x.shape
|
||||
x = x.reshape(b5 * t5, c5, h5, w5)
|
||||
bs, c, H_orig, W_orig = x.shape
|
||||
patch = self.patch
|
||||
# Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end.
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch))
|
||||
H, W = x.shape[-2], x.shape[-1]
|
||||
h_, w_ = H // patch, W // patch
|
||||
|
||||
# context arrives as (B, seq, txtlayers*txtdim); reshape to (B, txtlayers, seq, txtdim).
|
||||
context = self._unpack_context(context)
|
||||
|
||||
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch)
|
||||
img = self.first(img)
|
||||
|
||||
t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype))
|
||||
tvec = self.tproj(t)
|
||||
|
||||
context = self.txtfusion(context, mask=None, transformer_options=transformer_options)
|
||||
context = self.txtmlp(context)
|
||||
|
||||
txtlen, imglen = context.shape[1], img.shape[1]
|
||||
combined = torch.cat((context, img), dim=1)
|
||||
|
||||
# Position ids: text at 0, image at (0, h_idx, w_idx).
|
||||
device = combined.device
|
||||
txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32)
|
||||
imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32)
|
||||
imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None]
|
||||
imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :]
|
||||
imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1)
|
||||
pos = torch.cat((txtpos, imgpos), dim=1)
|
||||
|
||||
freqs = self.pe_embedder(pos)
|
||||
|
||||
for block in self.blocks:
|
||||
combined = block(combined, tvec, freqs, None, transformer_options=transformer_options)
|
||||
|
||||
final = self.last(combined, t)
|
||||
out = final[:, txtlen:txtlen + imglen, :]
|
||||
out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
||||
h=h_, w=w_, ph=patch, pw=patch, c=self.channels)
|
||||
out = out[:, :, :H_orig, :W_orig] # crop padding back off
|
||||
if temporal:
|
||||
out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2)
|
||||
return out
|
||||
|
||||
def _unpack_context(self, context):
|
||||
# context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim).
|
||||
b, seq, fused = context.shape
|
||||
if fused != self.txtlayers * self.txtdim:
|
||||
raise ValueError(
|
||||
f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} "
|
||||
f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. "
|
||||
f"Load the text encoder with CLIPLoader type 'krea2'."
|
||||
)
|
||||
return context.reshape(b, seq, self.txtlayers, self.txtdim)
|
||||
@ -1085,7 +1085,7 @@ class LTXVModel(LTXBaseModel):
|
||||
)
|
||||
|
||||
grid_mask = None
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
additional_args.update({ "orig_patchified_shape": list(x.shape)})
|
||||
denoise_mask = self.patchifier.patchify(denoise_mask)[0]
|
||||
grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
|
||||
@ -1330,7 +1330,7 @@ class LTXVModel(LTXBaseModel):
|
||||
x = x * (1 + scale) + shift
|
||||
x = self.proj_out(x)
|
||||
|
||||
if keyframe_idxs is not None:
|
||||
if keyframe_idxs is not None and keyframe_idxs.shape[2] > 0:
|
||||
grid_mask = kwargs["grid_mask"]
|
||||
orig_patchified_shape = kwargs["orig_patchified_shape"]
|
||||
full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -17,13 +18,11 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
return apply_rope1(x, freqs_cis)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return F.silu(x) * y
|
||||
return F.silu(x, inplace=True).mul_(y)
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
|
||||
@ -51,6 +51,18 @@ class FeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Addin this back because Nunchaku custom nodes rely on it, see comment here:
|
||||
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
|
||||
# TODO: Eventually remove this once we natively support SVDQuants
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -8,7 +8,7 @@ from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.math import apply_rope1, rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -570,6 +570,14 @@ class WanModel(torch.nn.Module):
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# In-context reference (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
main_len = x.shape[1]
|
||||
if context_latents is not None:
|
||||
for lat in context_latents:
|
||||
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
|
||||
x = torch.cat([x, cl], dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@ -599,6 +607,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if context_latents is not None:
|
||||
x = x[:, :main_len]
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
@ -606,7 +617,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@ -638,6 +649,13 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
|
||||
if source_id:
|
||||
d = self.dim // self.num_heads
|
||||
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
|
||||
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
|
||||
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
@ -661,6 +679,15 @@ class WanModel(torch.nn.Module):
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
# In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0).
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
|
||||
for i, lat in enumerate(context_latents):
|
||||
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
|
||||
kwargs = {**kwargs, "context_latents": context_latents}
|
||||
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None: # SCAIL-2 additive mask stream (one identity mask frame per reference, then video)
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel):
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
@ -1695,16 +1726,44 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||
# reference_latent may stack several frames: the last is the primary reference adjacent to the video, the earlier frames are additional references.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
|
||||
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
main_t_patches = t - ref_t_patches
|
||||
video_t_start = max(ref_t_patches - 1, 0)
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super().rope_encode(main_t_patches, h, w, t_start=video_t_start, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=video_t_start, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
return main_freqs
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
|
||||
# if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames
|
||||
@ -1719,12 +1778,16 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
if ref_mask_latents is not None: # SCAIL-2
|
||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||
if sam_latents is not None: # SCAIL-2
|
||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
@ -1737,5 +1800,15 @@ class SCAILWanModel(WanModel):
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Krea2):
|
||||
diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
@ -357,6 +368,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
||||
import torch
|
||||
import logging
|
||||
import comfy.ldm.lightricks.av_model
|
||||
import comfy.ldm.lightricks.symmetric_patchifier
|
||||
import comfy.context_windows
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from comfy.ldm.cascade.stage_c import StageC
|
||||
@ -54,8 +55,10 @@ import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.boogu.model
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.krea2.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
@ -65,6 +68,7 @@ import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
import comfy.ldm.hidream_o1.model
|
||||
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
|
||||
import comfy.ldm.depth_anything_3.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1202,6 +1206,127 @@ class LTXAV(BaseModel):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
def map_context_window_to_modalities(self, primary_indices, latent_shapes, dim):
|
||||
result = [primary_indices]
|
||||
if len(latent_shapes) < 2:
|
||||
return result
|
||||
|
||||
video_total = latent_shapes[0][dim]
|
||||
|
||||
for i in range(1, len(latent_shapes)):
|
||||
mod_total = latent_shapes[i][dim]
|
||||
# Map each primary index to its proportional range of modality indices and
|
||||
# concatenate in order. Preserves wrapped/strided geometry so the modality
|
||||
# attends to the same temporal regions as the primary window.
|
||||
mod_indices = []
|
||||
seen = set()
|
||||
for v_idx in primary_indices:
|
||||
a_start = min(int(round(v_idx * mod_total / video_total)), mod_total - 1)
|
||||
a_end = min(int(round((v_idx + 1) * mod_total / video_total)), mod_total)
|
||||
if a_end <= a_start:
|
||||
a_end = a_start + 1
|
||||
for a in range(a_start, a_end):
|
||||
if a not in seen:
|
||||
seen.add(a)
|
||||
mod_indices.append(a)
|
||||
result.append(mod_indices)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_guide_entries(conds):
|
||||
for cond_list in conds:
|
||||
if cond_list is None:
|
||||
continue
|
||||
for cond_dict in cond_list:
|
||||
model_conds = cond_dict.get('model_conds', {})
|
||||
entries = model_conds.get('guide_attention_entries')
|
||||
if entries is not None and hasattr(entries, 'cond') and entries.cond:
|
||||
return entries.cond
|
||||
return None
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# Audio denoise mask — slice using audio modality window
|
||||
if cond_key == "audio_denoise_mask" and hasattr(window, 'modality_windows') and window.modality_windows:
|
||||
audio_window = window.modality_windows.get(1)
|
||||
if audio_window is not None and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
sliced = audio_window.get_tensor(cond_value.cond, device, dim=2)
|
||||
return cond_value._copy_with(sliced)
|
||||
|
||||
# Video denoise mask — split into video + guide portions, slice each
|
||||
if cond_key == "denoise_mask" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
cond_tensor = cond_value.cond
|
||||
guide_count = cond_tensor.size(window.dim) - x_in.size(window.dim)
|
||||
if guide_count > 0:
|
||||
T_video = x_in.size(window.dim)
|
||||
video_mask = cond_tensor.narrow(window.dim, 0, T_video)
|
||||
guide_mask = cond_tensor.narrow(window.dim, T_video, guide_count)
|
||||
sliced_video = window.get_tensor(video_mask, device, retain_index_list=retain_index_list)
|
||||
suffix_indices = window.guide_frames_indices
|
||||
if suffix_indices:
|
||||
idx = tuple([slice(None)] * window.dim + [suffix_indices])
|
||||
sliced_guide = guide_mask[idx].to(device)
|
||||
return cond_value._copy_with(torch.cat([sliced_video, sliced_guide], dim=window.dim))
|
||||
else:
|
||||
return cond_value._copy_with(sliced_video)
|
||||
|
||||
# Keyframe indices — regenerate pixel coords for window, select guide positions
|
||||
if cond_key == "keyframe_idxs":
|
||||
kf_local_pos = window.guide_kf_local_positions
|
||||
if not kf_local_pos:
|
||||
return cond_value._copy_with(cond_value.cond[:, :, :0, :]) # empty
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
window_len = len(window.index_list)
|
||||
# account for causal_window_fix anchor in coord space size
|
||||
anchor_idx = getattr(window, 'causal_anchor_index', None)
|
||||
if anchor_idx is not None and anchor_idx >= 0:
|
||||
window_len += 1
|
||||
patchifier = self.diffusion_model.patchifier
|
||||
latent_coords = patchifier.get_latent_coords(window_len, H, W, 1, cond_value.cond.device)
|
||||
scale_factors = self.diffusion_model.vae_scale_factors
|
||||
pixel_coords = comfy.ldm.lightricks.symmetric_patchifier.latent_to_pixel_coords(
|
||||
latent_coords,
|
||||
scale_factors,
|
||||
causal_fix=self.diffusion_model.causal_temporal_positioning)
|
||||
tokens = []
|
||||
for pos in kf_local_pos:
|
||||
tokens.extend(range(pos * H * W, (pos + 1) * H * W))
|
||||
pixel_coords = pixel_coords[:, :, tokens, :]
|
||||
|
||||
# Adjust spatial end positions for dilated (downscaled) guides.
|
||||
# Each guide entry may have a different downscale factor; expand the
|
||||
# per-entry factor to cover all tokens belonging to that entry.
|
||||
downscale_factors = window.guide_downscale_factors
|
||||
overlap_info = window.guide_overlap_info
|
||||
if downscale_factors:
|
||||
per_token_factor = []
|
||||
for (entry_idx, overlap_count), dsf in zip(overlap_info, downscale_factors):
|
||||
per_token_factor.extend([dsf] * (overlap_count * H * W))
|
||||
factor_tensor = torch.tensor(per_token_factor, device=pixel_coords.device, dtype=pixel_coords.dtype)
|
||||
spatial_end_offset = (factor_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(-1) - 1) * torch.tensor(
|
||||
scale_factors[1:], device=pixel_coords.device, dtype=pixel_coords.dtype,
|
||||
).view(1, -1, 1, 1)
|
||||
pixel_coords[:, 1:, :, 1:] += spatial_end_offset
|
||||
|
||||
B = cond_value.cond.shape[0]
|
||||
if B > 1:
|
||||
pixel_coords = pixel_coords.expand(B, -1, -1, -1)
|
||||
return cond_value._copy_with(pixel_coords)
|
||||
|
||||
# Guide attention entries — adjust per-guide counts based on window overlap
|
||||
if cond_key == "guide_attention_entries":
|
||||
overlap_info = window.guide_overlap_info
|
||||
H, W = x_in.shape[3], x_in.shape[4]
|
||||
new_entries = []
|
||||
for entry_idx, overlap_count in overlap_info:
|
||||
e = cond_value.cond[entry_idx]
|
||||
new_entries.append({**e,
|
||||
"pre_filter_count": overlap_count * H * W,
|
||||
"latent_shape": [overlap_count, H, W]})
|
||||
return cond_value._copy_with(new_entries)
|
||||
|
||||
return None
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
@ -1518,8 +1643,26 @@ class WAN21(BaseModel):
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
# In-context reference conditioning (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# In-context cond slicing (Bernini)
|
||||
if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list):
|
||||
dim = window.dim
|
||||
out = []
|
||||
for lat in cond_value.cond:
|
||||
if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]:
|
||||
out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list))
|
||||
else:
|
||||
out.append(lat.to(device))
|
||||
return cond_value._copy_with(out)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@ -1728,10 +1871,14 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
reference_latents = kwargs.get("reference_latents", None)
|
||||
if reference_latents is not None:
|
||||
ref_latent = self.process_latent_in(reference_latents[-1])
|
||||
ref_mask = torch.ones_like(ref_latent[:, :4])
|
||||
ref_latent = torch.cat([ref_latent, ref_mask], dim=1)
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(ref_latent)
|
||||
# SCAIL-2 multi-reference: reference_latents[0] is the primary ref, [1:] are additional
|
||||
# references. Stack as [additional..., primary] so the primary stays adjacent to the video.
|
||||
ordered = list(reference_latents[1:]) + list(reference_latents[:1])
|
||||
stacked = []
|
||||
for lat in ordered:
|
||||
lat = self.process_latent_in(lat)
|
||||
stacked.append(torch.cat([lat, torch.ones_like(lat[:, :4])], dim=1))
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(torch.cat(stacked, dim=2))
|
||||
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
@ -1754,6 +1901,99 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN21_SCAIL2(WAN21_SCAIL):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
|
||||
self.memory_usage_shape_process = {
|
||||
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
# ref_mask_28ch holds one identity mask per stacked reference frame (additional refs first, then the primary ref), followed by zeros over the video frames.
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_flag = kwargs.get("ref_mask_flag", None)
|
||||
if ref_mask_flag is not None:
|
||||
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = super().extra_conds_shapes(**kwargs)
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
s = driving_mask_28ch.shape
|
||||
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
s = ref_mask_28ch.shape
|
||||
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key in ("sam_latents", "pose_latents"):
|
||||
# Return sliced view omitting retain_index_list
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
|
||||
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
# The ref mask is N leading ref frames padded with frames of zeros, so just grab the first frames for all windows
|
||||
full_ref_mask = cond_value.cond
|
||||
video_frame_count = x_in.shape[2]
|
||||
ref_frame_count = full_ref_mask.shape[2] - video_frame_count
|
||||
if ref_frame_count < 1:
|
||||
return None
|
||||
window_length = len(window.index_list)
|
||||
|
||||
# Account for the causal anchor frame if it exists
|
||||
anchor_index = getattr(window, "causal_anchor_index", None)
|
||||
if anchor_index is not None and anchor_index >= 0:
|
||||
window_length += 1
|
||||
|
||||
window_ref_mask = full_ref_mask[:, :, :window_length + ref_frame_count].to(device)
|
||||
return cond_value._copy_with(window_ref_mask)
|
||||
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
|
||||
if extra_channels != 4:
|
||||
return super().concat_cond(**kwargs)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
return torch.zeros_like(noise)[:, :4]
|
||||
|
||||
device = kwargs["device"]
|
||||
if mask.shape[1] != 4:
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return mask
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
|
||||
return latent_image
|
||||
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
|
||||
@ -1987,6 +2227,11 @@ class Omnigen2(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class Boogu(Omnigen2):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(Omnigen2, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.boogu.model.BooguTransformer2DModel)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
class QwenImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.qwen_image.model.QwenImageTransformer2DModel)
|
||||
@ -2034,6 +2279,17 @@ class Ideogram4(BaseModel):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class Krea2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
@ -2227,6 +2483,12 @@ class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
|
||||
class DepthAnything3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail2"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
|
||||
@ -759,6 +761,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}double_stream_layers.0.img_instruct_attn.processor.img_to_q.weight'.format(key_prefix) in state_dict_keys: # Boogu-Image (OmniGen2 derivative + dual-stream stage)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "boogu"
|
||||
dit_config["hidden_size"] = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}single_stream_layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_double_stream_layers"] = count_blocks(state_dict_keys, '{}double_stream_layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_refiner_layers"] = count_blocks(state_dict_keys, '{}noise_refiner.'.format(key_prefix) + '{}.')
|
||||
dit_config["instruction_feat_dim"] = state_dict['{}time_caption_embed.caption_embedder.0.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}time_caption_embed.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: # Omnigen2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "omnigen2"
|
||||
@ -822,6 +834,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "krea2"
|
||||
head_dim = 128
|
||||
first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2)
|
||||
dit_config["features"] = first_w.shape[0]
|
||||
dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2
|
||||
dit_config["patch"] = 2
|
||||
dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim
|
||||
dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
@ -860,6 +887,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
|
||||
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
if backbone_name is None:
|
||||
return None
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
|
||||
# cat_token=True is implied by the presence of camera_token.
|
||||
if has_camera_token:
|
||||
dit_config["alt_start"] = qknorm_start
|
||||
dit_config["rope_start"] = qknorm_start
|
||||
dit_config["qknorm_start"] = qknorm_start
|
||||
dit_config["cat_token"] = True
|
||||
else:
|
||||
dit_config["alt_start"] = -1
|
||||
dit_config["rope_start"] = -1
|
||||
dit_config["qknorm_start"] = -1
|
||||
dit_config["cat_token"] = False
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_out_channels"] = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
if has_aux:
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
dit_config["head_output_dim"] = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
# out_layers: hard-coded per upstream YAML config (depth-aware default).
|
||||
if depth >= 24:
|
||||
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
|
||||
if has_aux:
|
||||
dit_config["out_layers"] = [11, 15, 19, 23]
|
||||
else:
|
||||
dit_config["out_layers"] = [4, 11, 17, 23]
|
||||
else:
|
||||
# vits/vitb: 12 blocks
|
||||
dit_config["out_layers"] = [5, 7, 9, 11]
|
||||
|
||||
# Camera encoder/decoder presence (multi-view + pose path).
|
||||
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
|
||||
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["has_cam_enc"] = has_cam_enc
|
||||
dit_config["has_cam_dec"] = has_cam_dec
|
||||
if has_cam_enc:
|
||||
cam_enc_w = state_dict.get(
|
||||
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_enc_w is not None:
|
||||
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
|
||||
if has_cam_dec:
|
||||
cam_dec_w = state_dict.get(
|
||||
'{}cam_dec.fc_t.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_dec_w is not None:
|
||||
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
|
||||
@ -534,8 +534,10 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def set_cudnn_benchmark():
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
|
||||
torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
@ -641,6 +643,8 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.high_ram:
|
||||
return True
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
@ -651,8 +655,7 @@ def ensure_pin_budget(size, evict_active=False):
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
def free_registrations(shortfall, evict_active=True):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
if shortfall <= 0:
|
||||
@ -674,6 +677,9 @@ def ensure_pin_registerable(size, evict_active=True):
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
@ -956,8 +962,6 @@ def loaded_models(only_currently_used=False):
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@ -1494,6 +1498,8 @@ if not args.disable_pinned_memory:
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
if args.high_ram:
|
||||
return max(0, int(size * 2))
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
|
||||
@ -379,10 +379,11 @@ class ModelPatcher:
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
|
||||
class_ = self.__class__
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if self.is_dynamic() and disable_dynamic or force_deepcopy:
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if model_override is None:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
|
||||
18
comfy/ops.py
18
comfy/ops.py
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
if signature is None or args.high_ram:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
@ -299,21 +299,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
if hasattr(s, "_v") and comfy.model_management.is_device_cpu(device):
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
elif hasattr(s, "_v") and s.weight.device != device:
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
|
||||
@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None):
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
extended = False
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
hostbuf.extend(size=size, register=False)
|
||||
extended = True
|
||||
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
comfy.model_management.free_registrations(size)
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
del pin
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
except RuntimeError:
|
||||
if extended:
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
module._pin = pin
|
||||
stack.append((module, offset))
|
||||
module._pin_registered = True
|
||||
module._pin_stack_index = len(stack) - 1
|
||||
|
||||
31
comfy/sd.py
31
comfy/sd.py
@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.krea2
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
@ -67,6 +68,8 @@ import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.qwen35
|
||||
import comfy.text_encoders.qwen3vl
|
||||
import comfy.text_encoders.boogu
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
@ -1302,6 +1305,8 @@ class CLIPType(Enum):
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
IDEOGRAM4 = 30
|
||||
BOOGU = 31
|
||||
KREA2 = 32
|
||||
|
||||
|
||||
|
||||
@ -1355,6 +1360,8 @@ class TEModel(Enum):
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
QWEN3VL_4B = 34
|
||||
QWEN3VL_8B = 35
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1416,6 +1423,8 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 5120:
|
||||
return TEModel.QWEN35_27B
|
||||
return TEModel.QWEN35_2B
|
||||
if "model.visual.deepstack_merger_list.0.norm.weight" in sd: # DeepStack is unique to Qwen3-VL
|
||||
return TEModel.QWEN3VL_4B if sd["model.visual.merger.linear_fc2.weight"].shape[0] == 2560 else TEModel.QWEN3VL_8B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
@ -1614,6 +1623,28 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
qwen35_type = {TEModel.QWEN35_08B: "qwen35_08b", TEModel.QWEN35_2B: "qwen35_2b", TEModel.QWEN35_4B: "qwen35_4b", TEModel.QWEN35_9B: "qwen35_9b", TEModel.QWEN35_27B: "qwen35_27b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen35.te(**llama_detect(clip_data), model_type=qwen35_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen35.tokenizer(model_type=qwen35_type)
|
||||
elif te_model in (TEModel.QWEN3VL_4B, TEModel.QWEN3VL_8B):
|
||||
if clip_type == CLIPType.IDEOGRAM4 and te_model == TEModel.QWEN3VL_8B: # Ideogram4 reuses the full Qwen3-VL-8B (13-layer tap for conditioning + multimodal generate).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.ideogram4.te_qwen3vl(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Qwen3VLTokenizer
|
||||
elif clip_type == CLIPType.BOOGU and te_model == TEModel.QWEN3VL_8B: # Boogu-Image: full Qwen3-VL-8B, last hidden state, no-think template.
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer
|
||||
elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate).
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer
|
||||
elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused.
|
||||
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b"
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B if te_model == TEModel.QWEN3VL_8B else comfy.text_encoders.flux.KleinTokenizer
|
||||
else:
|
||||
clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."})
|
||||
qwen3vl_type = {TEModel.QWEN3VL_4B: "qwen3vl_4b", TEModel.QWEN3VL_8B: "qwen3vl_8b"}[te_model]
|
||||
clip_target.clip = comfy.text_encoders.qwen3vl.te(**llama_detect(clip_data), model_type=qwen3vl_type)
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen3vl.tokenizer(model_type=qwen3vl_type)
|
||||
elif te_model == TEModel.QWEN3_06B:
|
||||
clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer
|
||||
|
||||
@ -25,6 +25,8 @@ import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.boogu
|
||||
import comfy.text_encoders.krea2
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
@ -1450,6 +1452,17 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_SCAIL2(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "scail2",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1747,6 +1760,27 @@ class Omnigen2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
class Boogu(Omnigen2):
|
||||
unet_config = {
|
||||
"image_model": "boogu",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.16,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Boogu(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.boogu.BooguTokenizer, comfy.text_encoders.boogu.te(**hunyuan_detect))
|
||||
|
||||
class Ideogram4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ideogram4",
|
||||
@ -1785,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Krea2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "krea2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.15,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.2
|
||||
|
||||
latent_format = latent_formats.Wan21
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Krea2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect))
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "qwen_image",
|
||||
@ -2046,6 +2109,23 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
return None
|
||||
|
||||
|
||||
class DepthAnything3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "DepthAnything3",
|
||||
}
|
||||
|
||||
# Mono path: no num_heads / num_head_channels needed.
|
||||
unet_extra_config = {}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.DepthAnything3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
@ -2260,6 +2340,7 @@ models = [
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
WAN21_SCAIL2,
|
||||
WAN22_WanDancer,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
@ -2272,8 +2353,10 @@ models = [
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
Omnigen2,
|
||||
Boogu,
|
||||
QwenImage,
|
||||
Ideogram4,
|
||||
Krea2,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
@ -2287,4 +2370,5 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
DepthAnything3,
|
||||
]
|
||||
|
||||
58
comfy/text_encoders/boogu.py
Normal file
58
comfy/text_encoders/boogu.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Boogu-Image text encoder: full Qwen3-VL-8B, last hidden state (4096-dim).
|
||||
|
||||
Boogu uses the final hidden state of Qwen3-VL as the per-token instruction feature
|
||||
(num_instruction_feature_layers=1, reduce_type=mean -> just the last layer).
|
||||
The model itself is the standard Qwen3-VL TE, only the chat template differs
|
||||
(a fixed system prompt and no <think> block).
|
||||
"""
|
||||
|
||||
import comfy.text_encoders.qwen3vl
|
||||
from comfy import sd1_clip
|
||||
|
||||
|
||||
# System prompts from the reference pipeline (pipeline_boogu.py).
|
||||
# T2I (non-empty instruction, no image) uses the helpful-assistant prompt
|
||||
# everything else (the CFG negative / "drop" condition, and any image case) uses the TI2I "describe" prompt.
|
||||
BOOGU_T2I_SYSTEM = "You are a helpful assistant that generates high-quality images based on user instructions. The instructions are as follows."
|
||||
BOOGU_DROP_SYSTEM = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
|
||||
|
||||
|
||||
class BooguTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b")
|
||||
# apply_chat_template without add_generation_prompt
|
||||
self.llama_template = "<|im_start|>system\n" + BOOGU_T2I_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
|
||||
self.llama_template_images = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
|
||||
# Reference SYSTEM_PROMPT_DROP: used for the empty negative/uncond instruction.
|
||||
self.llama_template_drop = "<|im_start|>system\n" + BOOGU_DROP_SYSTEM + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
|
||||
if llama_template is None and len(images) == 0 and text.strip() == "":
|
||||
llama_template = self.llama_template_drop
|
||||
# Boogu conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds by default.
|
||||
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
|
||||
|
||||
|
||||
class BooguQwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
|
||||
super().__init__(device=device, dtype=dtype, attention_mask=attention_mask, model_options=model_options, model_type=model_type)
|
||||
# apply the final RMSNorm to the tapped last layer
|
||||
self.layer_norm_hidden_state = True
|
||||
|
||||
|
||||
class BooguTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
clip_model = lambda **kw: BooguQwen3VLClipModel(**kw, model_type="qwen3vl_8b")
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class BooguTEModel_(BooguTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return BooguTEModel_
|
||||
@ -9,6 +9,7 @@ import os
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
import comfy.text_encoders.llama
|
||||
import comfy.text_encoders.qwen3vl
|
||||
from comfy import sd1_clip
|
||||
|
||||
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
|
||||
@ -32,7 +33,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
if text.startswith('<|im_start|>'):
|
||||
llama_text = text
|
||||
elif llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
@ -75,3 +78,43 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Ideogram4TEModel_
|
||||
|
||||
|
||||
# Full Qwen3-VL-8B variant with vision
|
||||
|
||||
class Ideogram4Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None, dtype=dtype,
|
||||
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_8b")
|
||||
|
||||
|
||||
class Ideogram4Qwen3VLTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Ideogram4Qwen3VLClipModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096), ascending layer order.
|
||||
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13 = 53248).
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
class Ideogram4Qwen3VLTokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_8b")
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
|
||||
# Ideogram 4 conditions on the no-think template; default thinking=True drops the empty think block qwen3vl adds.
|
||||
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
|
||||
|
||||
|
||||
def te_qwen3vl(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Ideogram4Qwen3VLTEModel_(Ideogram4Qwen3VLTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Ideogram4Qwen3VLTEModel_
|
||||
|
||||
84
comfy/text_encoders/krea2.py
Normal file
84
comfy/text_encoders/krea2.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap.
|
||||
|
||||
K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B
|
||||
(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and
|
||||
consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor,
|
||||
so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model.
|
||||
"""
|
||||
|
||||
import numbers
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.text_encoders.qwen3vl
|
||||
from comfy import sd1_clip
|
||||
|
||||
# tap k == hidden_states[k] (no offset).
|
||||
KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35]
|
||||
|
||||
# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix.
|
||||
KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b")
|
||||
self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template.
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs):
|
||||
# Krea2 conditions on the no-think template; thinking=True drops the empty <think> block qwen3vl adds.
|
||||
return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs)
|
||||
|
||||
|
||||
class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype,
|
||||
attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b")
|
||||
|
||||
|
||||
class Krea2TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560)
|
||||
tok_pairs = token_weight_pairs["qwen3vl_4b"][0]
|
||||
|
||||
# Strip the system + user-opening prefix
|
||||
count_im_start = 0
|
||||
if template_end == -1:
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral):
|
||||
if elem == 151644 and count_im_start < 2:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
if out.shape[2] > (template_end + 3):
|
||||
if tok_pairs[template_end + 1][0] == 872: # "user"
|
||||
if tok_pairs[template_end + 2][0] == 198: # "\n"
|
||||
template_end += 3
|
||||
|
||||
out = out[:, :, template_end:]
|
||||
|
||||
b, n, seq, h = out.shape
|
||||
# Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model.
|
||||
out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h)
|
||||
|
||||
if "attention_mask" in extra:
|
||||
extra["attention_mask"] = extra["attention_mask"][:, template_end:]
|
||||
if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
|
||||
extra.pop("attention_mask")
|
||||
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Krea2TEModel_(Krea2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Krea2TEModel_
|
||||
@ -251,6 +251,19 @@ class Qwen3_8BConfig:
|
||||
lm_head: bool = True
|
||||
stop_tokens = [151643, 151645]
|
||||
|
||||
@dataclass
|
||||
class Qwen3VL_8BConfig(Qwen3_8BConfig):
|
||||
max_position_embeddings: int = 262144
|
||||
rope_theta: float = 5000000.0
|
||||
rope_dims = [24, 20, 20]
|
||||
interleaved_mrope = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3VL_4BConfig(Qwen3VL_8BConfig):
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
lm_head: bool = False # 4B ties word embeddings
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
@ -703,7 +716,8 @@ class Llama2_(nn.Module):
|
||||
interleaved_mrope=getattr(self.config, "interleaved_mrope", False),
|
||||
device=device)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True,
|
||||
dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None,deepstack_embeds=None, visual_pos_masks=None):
|
||||
if embeds is not None:
|
||||
x = embeds
|
||||
else:
|
||||
@ -767,6 +781,10 @@ class Llama2_(nn.Module):
|
||||
if current_kv is not None:
|
||||
next_key_values.append(current_kv)
|
||||
|
||||
# DeepStack: add per-layer visual features into the first len() decoder layers at image positions (Qwen3-VL)
|
||||
if deepstack_embeds is not None and i < len(deepstack_embeds):
|
||||
x[visual_pos_masks] = x[visual_pos_masks] + deepstack_embeds[i].to(x)
|
||||
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
@ -860,7 +878,7 @@ class BaseGenerate:
|
||||
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
|
||||
return past_key_values
|
||||
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
|
||||
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None, position_ids=None, deepstack_embeds=None, visual_pos_masks=None):
|
||||
device = embeds.device
|
||||
|
||||
if stop_tokens is None:
|
||||
@ -884,10 +902,18 @@ class BaseGenerate:
|
||||
generated_token_ids = []
|
||||
pbar = comfy.utils.ProgressBar(max_length)
|
||||
|
||||
# MRoPE: prefill uses explicit 3D position_ids, decode continues from the last position
|
||||
next_pos = int(position_ids[:, -1].max()) + 1 if position_ids is not None else None
|
||||
|
||||
# Generation loop
|
||||
current_input_ids = initial_input_ids
|
||||
for step in tqdm(range(max_length), desc="Generating tokens"):
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
|
||||
# DeepStack visual features are injected on the prefill only; gemma4's forward lacks these kwargs.
|
||||
extra = {}
|
||||
if step == 0 and deepstack_embeds is not None:
|
||||
extra["deepstack_embeds"] = deepstack_embeds
|
||||
extra["visual_pos_masks"] = visual_pos_masks
|
||||
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids, position_ids=position_ids, **extra)
|
||||
logits = self.logits(x)[:, -1]
|
||||
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
|
||||
token_id = next_token[0].item()
|
||||
@ -895,6 +921,9 @@ class BaseGenerate:
|
||||
|
||||
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
|
||||
current_input_ids = next_token if initial_input_ids is not None else None
|
||||
if next_pos is not None: # advance MRoPE position for the next (decode) step
|
||||
position_ids = torch.tensor([[next_pos]], device=device)
|
||||
next_pos += 1
|
||||
pbar.update(1)
|
||||
|
||||
if token_id in stop_tokens:
|
||||
|
||||
@ -3,7 +3,6 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
import math
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
@ -563,6 +562,8 @@ class Qwen35VisionModel(nn.Module):
|
||||
for _ in range(config["depth"])
|
||||
])
|
||||
self.merger = Qwen35VisionPatchMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
self.deepstack_visual_indexes = [] # DeepStack, per-layer visual features (Qwen3-VL)
|
||||
self.deepstack_merger_list = None
|
||||
|
||||
def rot_pos_emb(self, grid_thw):
|
||||
merge_size = self.spatial_merge_size
|
||||
@ -664,9 +665,14 @@ class Qwen35VisionModel(nn.Module):
|
||||
).cumsum(dim=0, dtype=torch.int32)
|
||||
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
||||
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||
for blk in self.blocks:
|
||||
deepstack_features = []
|
||||
for layer_num, blk in enumerate(self.blocks):
|
||||
x = blk(x, cu_seqlens=cu_seqlens, position_embeddings=position_embeddings, optimized_attention=optimized_attention)
|
||||
if self.deepstack_merger_list is not None and layer_num in self.deepstack_visual_indexes:
|
||||
deepstack_features.append(self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](x))
|
||||
merged = self.merger(x)
|
||||
if self.deepstack_merger_list is not None:
|
||||
return merged, deepstack_features
|
||||
return merged
|
||||
|
||||
# Model Wrapper
|
||||
@ -690,30 +696,7 @@ class Qwen35(BaseLlama, BaseGenerate, torch.nn.Module):
|
||||
return None, None
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[], past_key_values=None):
|
||||
grid = None
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
grid = e.get("extra", None)
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
|
||||
if grid is None:
|
||||
position_ids = None
|
||||
|
||||
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, embeds.shape[1], embeds.device)
|
||||
return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids, past_key_values=past_key_values)
|
||||
|
||||
def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
|
||||
|
||||
193
comfy/text_encoders/qwen3vl.py
Normal file
193
comfy/text_encoders/qwen3vl.py
Normal file
@ -0,0 +1,193 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.qwen_vl
|
||||
from .qwen35 import Qwen35VisionModel
|
||||
from .llama import BaseLlama, BaseQwen3, BaseGenerate, Llama2_, Qwen3VL_4BConfig, Qwen3VL_8BConfig
|
||||
|
||||
|
||||
QWEN3VL_VISION = {
|
||||
"qwen3vl_4b": dict(hidden_size=1024, intermediate_size=4096, depth=24, deepstack_visual_indexes=[5, 11, 17]),
|
||||
"qwen3vl_8b": dict(hidden_size=1152, intermediate_size=4304, depth=27, deepstack_visual_indexes=[8, 16, 24]),
|
||||
}
|
||||
QWEN3VL_VISION_COMMON = dict(num_heads=16, patch_size=16, temporal_patch_size=2, in_channels=3,
|
||||
spatial_merge_size=2, num_position_embeddings=2304)
|
||||
|
||||
QWEN3VL_CONFIGS = {"qwen3vl_4b": Qwen3VL_4BConfig, "qwen3vl_8b": Qwen3VL_8BConfig}
|
||||
|
||||
|
||||
class Qwen3VLDeepstackMerger(nn.Module):
|
||||
# DeepStack merger: postshuffle LayerNorm (applied after spatial merge), unlike the main merger.
|
||||
def __init__(self, hidden_size, spatial_merge_size, out_hidden_size, device=None, dtype=None, ops=None):
|
||||
super().__init__()
|
||||
self.merge_dim = hidden_size * (spatial_merge_size ** 2)
|
||||
self.norm = ops.LayerNorm(self.merge_dim, eps=1e-6, device=device, dtype=dtype)
|
||||
self.linear_fc1 = ops.Linear(self.merge_dim, self.merge_dim, device=device, dtype=dtype)
|
||||
self.linear_fc2 = ops.Linear(self.merge_dim, out_hidden_size, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x.view(-1, self.merge_dim))
|
||||
return self.linear_fc2(F.gelu(self.linear_fc1(x)))
|
||||
|
||||
|
||||
class Qwen3VLVisionModel(Qwen35VisionModel):
|
||||
# Qwen3.5 vision + DeepStack
|
||||
def __init__(self, config, device=None, dtype=None, ops=None):
|
||||
super().__init__(config, device=device, dtype=dtype, ops=ops)
|
||||
self.deepstack_visual_indexes = config["deepstack_visual_indexes"]
|
||||
self.deepstack_merger_list = nn.ModuleList([
|
||||
Qwen3VLDeepstackMerger(self.hidden_size, self.spatial_merge_size, config["out_hidden_size"], device=device, dtype=dtype, ops=ops)
|
||||
for _ in self.deepstack_visual_indexes
|
||||
])
|
||||
|
||||
|
||||
class Qwen3VL(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
|
||||
model_type = "qwen3vl_8b"
|
||||
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = QWEN3VL_CONFIGS[self.model_type](**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
vision_config = {**QWEN3VL_VISION_COMMON, **QWEN3VL_VISION[self.model_type], "out_hidden_size": config.hidden_size}
|
||||
self.visual = Qwen3VLVisionModel(vision_config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
def preprocess_embed(self, embed, device):
|
||||
if embed["type"] == "image":
|
||||
# Qwen3-VL normalizes to [-1, 1] (mean/std 0.5), unlike Qwen2.5-VL's CLIP normalization.
|
||||
image, grid = comfy.text_encoders.qwen_vl.process_qwen2vl_images(embed["data"], patch_size=16, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
|
||||
merged, deepstack = self.visual(image.to(device, dtype=torch.float32), grid)
|
||||
return merged, {"grid": grid, "deepstack": deepstack}
|
||||
return None, None
|
||||
|
||||
def build_image_inputs(self, embeds, embeds_info):
|
||||
# Returns (position_ids, visual_pos_masks, deepstack) for the prompt
|
||||
images = sorted([e for e in embeds_info if e.get("type") == "image"], key=lambda e: e["index"])
|
||||
if len(images) == 0:
|
||||
return None, None, None
|
||||
|
||||
device = embeds.device
|
||||
seq = embeds.shape[1]
|
||||
position_ids = comfy.text_encoders.qwen_vl.qwen2vl_mrope_position_ids(embeds_info, seq, device)
|
||||
|
||||
# DeepStack: mask of image positions + per-vision-layer features to inject there.
|
||||
visual_pos_masks = torch.zeros((1, seq), dtype=torch.bool, device=device)
|
||||
deepstack = None
|
||||
for e in images:
|
||||
start = e["index"]
|
||||
end = e["size"] + start
|
||||
visual_pos_masks[0, start:end] = True
|
||||
ds = e["extra"]["deepstack"]
|
||||
if deepstack is None:
|
||||
deepstack = [d for d in ds]
|
||||
else:
|
||||
deepstack = [torch.cat([deepstack[i], ds[i]], dim=0) for i in range(len(ds))]
|
||||
return position_ids, visual_pos_masks, deepstack
|
||||
|
||||
|
||||
def _make_qwen3vl_model(model_type):
|
||||
class Qwen3VL_(Qwen3VL):
|
||||
pass
|
||||
Qwen3VL_.model_type = model_type
|
||||
return Qwen3VL_
|
||||
|
||||
|
||||
class Qwen3VLClipModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}, model_type="qwen3vl_8b"):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={},
|
||||
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
|
||||
model_class=_make_qwen3vl_model(model_type), enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty=0.0):
|
||||
if isinstance(tokens, dict):
|
||||
tokens = next(iter(tokens.values()))
|
||||
tokens_only = [[t[0] for t in b] for b in tokens]
|
||||
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
|
||||
position_ids, visual_pos_masks, deepstack = self.transformer.build_image_inputs(embeds, embeds_info)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed,
|
||||
presence_penalty=presence_penalty, position_ids=position_ids,
|
||||
visual_pos_masks=visual_pos_masks, deepstack_embeds=deepstack)
|
||||
|
||||
|
||||
class Qwen3VLTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, model_type="qwen3vl_8b"):
|
||||
clip_model = lambda **kw: Qwen3VLClipModel(**kw, model_type=model_type)
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
class Qwen3VLSDTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, embedding_size=4096, embedding_key="qwen3vl_8b"):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Qwen3VLTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}, model_type="qwen3vl_8b"):
|
||||
embedding_size = 2560 if model_type == "qwen3vl_4b" else 4096
|
||||
tokenizer = lambda *a, **kw: Qwen3VLSDTokenizer(*a, **kw, embedding_size=embedding_size, embedding_key=model_type)
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=model_type, tokenizer=tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=False, **kwargs):
|
||||
image = kwargs.get("image", None)
|
||||
if image is not None and len(images) == 0:
|
||||
images = [image[i:i + 1] for i in range(image.shape[0])]
|
||||
|
||||
skip_template = text.startswith('<|im_start|>')
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
else:
|
||||
if llama_template is not None:
|
||||
template = llama_template
|
||||
elif len(images) == 0:
|
||||
template = self.llama_template
|
||||
else:
|
||||
template = self.llama_template_images
|
||||
if len(images) > 1:
|
||||
vision_block = "<|vision_start|><|image_pad|><|vision_end|>"
|
||||
template = template.replace(vision_block, vision_block * len(images), 1)
|
||||
llama_text = template.format(text)
|
||||
if not thinking: # Qwen3 convention: empty think block suppresses reasoning
|
||||
llama_text += "<think>\n\n</think>\n\n"
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
key_name = next(iter(tokens))
|
||||
embed_count = 0
|
||||
for r in tokens[key_name]:
|
||||
for i in range(len(r)):
|
||||
if r[i][0] == 151655: # <|image_pad|>
|
||||
if len(images) > embed_count:
|
||||
r[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"},) + r[i][1:]
|
||||
embed_count += 1
|
||||
return tokens
|
||||
|
||||
|
||||
def tokenizer(model_type="qwen3vl_8b"):
|
||||
class Qwen3VLTokenizer_(Qwen3VLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type=model_type)
|
||||
return Qwen3VLTokenizer_
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3vl_8b"):
|
||||
class Qwen3VLTEModel_(Qwen3VLTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, model_type=model_type)
|
||||
return Qwen3VLTEModel_
|
||||
@ -88,6 +88,32 @@ def process_qwen2vl_images(
|
||||
return flatten_patches, image_grid_thw
|
||||
|
||||
|
||||
def qwen2vl_mrope_position_ids(embeds_info, seq_len, device):
|
||||
# (3, seq_len) T/H/W MRoPE position ids: text runs sequentially, each image span gets its grid positions.
|
||||
# Returns None when there are no image embeds. `extra` is the image grid_thw, or a dict carrying it under "grid".
|
||||
position_ids = None
|
||||
offset = 0
|
||||
for e in embeds_info:
|
||||
if e.get("type") == "image":
|
||||
extra = e.get("extra", None)
|
||||
grid = extra["grid"] if isinstance(extra, dict) else extra
|
||||
start = e.get("index")
|
||||
if position_ids is None:
|
||||
position_ids = torch.zeros((3, seq_len), device=device)
|
||||
position_ids[:, :start] = torch.arange(0, start, device=device)
|
||||
end = e.get("size") + start
|
||||
len_max = int(grid.max()) // 2
|
||||
start_next = len_max + start
|
||||
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (seq_len - end) + offset, device=device)
|
||||
position_ids[0, start:end] = start + offset
|
||||
max_d = int(grid[0][1]) // 2
|
||||
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]
|
||||
max_d = int(grid[0][2]) // 2
|
||||
position_ids[2, start:end] = torch.arange(start + offset, start + max_d + offset, device=device).unsqueeze(0).repeat(math.ceil((end - start) / max_d), 1).flatten(0)[:end - start]
|
||||
offset += len_max - (end - start)
|
||||
return position_ids
|
||||
|
||||
|
||||
class VisionPatchEmbed(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def krea2_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("layers", 0)
|
||||
n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks
|
||||
n_txt_refiner = 2
|
||||
key_map = {}
|
||||
|
||||
def add_block(prefix_to, prefix_from):
|
||||
block_map = {
|
||||
"attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv",
|
||||
"attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo",
|
||||
"attn.to_out": "attn.wo", # some tools drop the ".0" on to_out
|
||||
"ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down",
|
||||
}
|
||||
for d, c in block_map.items():
|
||||
key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i))
|
||||
for i in range(n_txt_layerwise):
|
||||
add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i))
|
||||
for i in range(n_txt_refiner):
|
||||
add_block("text_fusion.refiner_blocks.{}".format(i), "txtfusion.refiner_blocks.{}".format(i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("img_in", "first"),
|
||||
("time_embed.linear_1", "tmlp.0"),
|
||||
("time_embed.linear_2", "tmlp.2"),
|
||||
("time_mod_proj", "tproj.1"),
|
||||
("txt_in.linear_1", "txtmlp.1"),
|
||||
("txt_in.linear_2", "txtmlp.3"),
|
||||
("text_fusion.projector", "txtfusion.projector"),
|
||||
("final_layer.linear", "last.linear"),
|
||||
]
|
||||
for d, c in MAP_BASIC:
|
||||
key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c)
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
|
||||
@ -25,6 +25,11 @@ CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
|
||||
"default": False,
|
||||
"description": "Show the sign-in button in the frontend even when not signed in",
|
||||
},
|
||||
"enable_telemetry": {
|
||||
"type": "bool",
|
||||
"default": False,
|
||||
"description": "Signal the frontend that telemetry collection is enabled",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
"""
|
||||
Returns the bit depth of the video (e.g. 8 or 10).
|
||||
|
||||
Default implementation returns 8; subclasses report their real depth.
|
||||
"""
|
||||
return 8
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
|
||||
if stream is None or stream.format is None or not stream.format.components:
|
||||
return 8
|
||||
return max(component.bits for component in stream.format.components)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
return video_stream_bit_depth(video_stream)
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
align_graph = None
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -310,7 +324,28 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
# Fix non-deterministic video decode when the video width is not a multiple of 32
|
||||
# For non-yuvj pixel formats: most H.264/H.265 video and static images (e.g. lossy WebP via LoadImage)
|
||||
# Pad both axes to a multiple of 32 and smear the border so the alignment padding never bleeds into the cropped edges
|
||||
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
|
||||
if align_graph is None:
|
||||
pad_w = ((frame.width + 31) // 32) * 32
|
||||
pad_h = ((frame.height + 31) // 32) * 32
|
||||
g = av.filter.Graph()
|
||||
g_src = g.add_buffer(width=frame.width, height=frame.height,
|
||||
format=frame.format.name, time_base=video_stream.time_base)
|
||||
g_pad = g.add('pad', f'{pad_w}:{pad_h}:0:0')
|
||||
g_fill = g.add('fillborders', f'left=0:right={pad_w - frame.width}:top=0:bottom={pad_h - frame.height}:mode=smear')
|
||||
g_sink = g.add('buffersink')
|
||||
g_src.link_to(g_pad)
|
||||
g_pad.link_to(g_fill)
|
||||
g_fill.link_to(g_sink)
|
||||
g.configure()
|
||||
align_graph = (g, g_src, g_sink)
|
||||
align_graph[1].push(frame)
|
||||
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:frame.height, :frame.width])
|
||||
else:
|
||||
img = frame.to_ndarray(format=image_format)
|
||||
if frame.rotation != 0:
|
||||
k = int(round(frame.rotation // 90))
|
||||
img = np.rot90(img, k=k, axes=(0, 1)).copy()
|
||||
@ -377,25 +412,32 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth is None:
|
||||
bit_depth = source_bit_depth
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -451,8 +493,10 @@ class VideoFromComponents(VideoInput):
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||
self.__components = components
|
||||
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||
self.__bit_depth = bit_depth
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -461,18 +505,26 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
return self.__bit_depth
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||
if bit_depth is None:
|
||||
bit_depth = self.__bit_depth
|
||||
is_10bit = bit_depth >= 10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -488,10 +540,11 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -505,9 +558,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO):
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_SPLAT_ANY")
|
||||
class File3DSplatAny(ComfyTypeIO):
|
||||
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
|
||||
class File3DPointCloudAny(ComfyTypeIO):
|
||||
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -1388,7 +1400,8 @@ class V3Data(TypedDict):
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
extra_pnginfo: Any, dynprompt: Any,
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str,
|
||||
comfy_usage_source: str = None, **kwargs):
|
||||
self.unique_id = unique_id
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
self.prompt = prompt
|
||||
@ -1401,6 +1414,8 @@ class HiddenHolder:
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
self.api_key_comfy_org = api_key_comfy_org
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
self.comfy_usage_source = comfy_usage_source
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
'''If hidden variable not found, return None.'''
|
||||
@ -1417,6 +1432,7 @@ class HiddenHolder:
|
||||
dynprompt=d.get(Hidden.dynprompt, None),
|
||||
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1439,6 +1455,8 @@ class Hidden(str, Enum):
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
comfy_usage_source = "COMFY_USAGE_SOURCE"
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1642,6 +1660,8 @@ class Schema:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
if Hidden.comfy_usage_source not in self.hidden:
|
||||
self.hidden.append(Hidden.comfy_usage_source)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if Hidden.prompt not in self.hidden:
|
||||
@ -2336,6 +2356,8 @@ __all__ = [
|
||||
"File3DSPLAT",
|
||||
"File3DSPZ",
|
||||
"File3DKSPLAT",
|
||||
"File3DSplatAny",
|
||||
"File3DPointCloudAny",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
|
||||
@ -285,7 +285,7 @@ class AudioSaveHelper:
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
file = f"{filename_with_batch_num}_{counter:05}.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
|
||||
9
comfy_api_nodes/apis/__init__.py
generated
9
comfy_api_nodes/apis/__init__.py
generated
@ -1310,13 +1310,6 @@ class KlingTaskStatus(str, Enum):
|
||||
failed = 'failed'
|
||||
|
||||
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
kling_v2_5_turbo = 'kling-v2-5-turbo'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
field_16_9 = '16:9'
|
||||
field_9_16 = '9:16'
|
||||
@ -5179,7 +5172,7 @@ class KlingText2VideoRequest(BaseModel):
|
||||
duration: Optional[KlingVideoGenDuration] = '5'
|
||||
external_task_id: Optional[str] = Field(None, description='Customized Task ID')
|
||||
mode: Optional[KlingVideoGenMode] = 'std'
|
||||
model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
|
||||
model_name: Optional[str] = 'kling-v1'
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=2500
|
||||
)
|
||||
|
||||
@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel):
|
||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||
)
|
||||
dilate_pixels: int = Field(10)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
|
||||
@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaVideoGreenScreenRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input video.")
|
||||
green_shade: str = Field(
|
||||
default="broadcast_green",
|
||||
description="Solid chroma-key shade applied behind the foreground "
|
||||
"(broadcast_green, chroma_green, or blue_screen).",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaVideoReplaceBackgroundRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
|
||||
background_url: str = Field(
|
||||
...,
|
||||
description="Publicly accessible URL of the background image or video to composite behind "
|
||||
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
|
||||
"undistorted results.",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
|
||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
imageConfig: GeminiImageConfig | None = Field(None)
|
||||
|
||||
@ -149,3 +149,59 @@ class MotionControlRequest(BaseModel):
|
||||
character_orientation: str = Field(...)
|
||||
mode: str = Field(..., description="'pro' or 'std'")
|
||||
model_name: str = Field(...)
|
||||
|
||||
|
||||
class Kling3TurboSettings(BaseModel):
|
||||
resolution: str = Field("720p", description="'720p' or '1080p'")
|
||||
aspect_ratio: str | None = Field(None, description="'16:9'/'9:16'/'1:1'; text-to-video only")
|
||||
duration: int = Field(5, description="3-15 second")
|
||||
|
||||
|
||||
class Kling3TurboText2VideoRequest(BaseModel):
|
||||
prompt: str = Field(..., description="<=3072 chars; may use multi-shot 'shot n, m, words; ...'")
|
||||
settings: Kling3TurboSettings | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboContent(BaseModel):
|
||||
type: str = Field(..., description="'prompt' or 'first_frame'")
|
||||
text: str | None = Field(None, description="for type=prompt; <=2500 chars")
|
||||
url: str | None = Field(None, description="for type=first_frame")
|
||||
|
||||
|
||||
class Kling3TurboImage2VideoRequest(BaseModel):
|
||||
contents: list[Kling3TurboContent] = Field(..., description="prompt + first_frame materials")
|
||||
settings: Kling3TurboSettings | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboCreateData(BaseModel):
|
||||
id: str | None = Field(None, description="Task ID")
|
||||
status: str | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboCreateResponse(BaseModel):
|
||||
code: int | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
request_id: str | None = Field(None)
|
||||
data: Kling3TurboCreateData | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboOutput(BaseModel):
|
||||
type: str | None = Field(None, description="'video', 'image', 'audio', ...")
|
||||
id: str | None = Field(None)
|
||||
url: str | None = Field(None)
|
||||
duration: str | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboTaskData(BaseModel):
|
||||
id: str | None = Field(None)
|
||||
status: str | None = Field(None, description="submitted | processing | succeeded | failed")
|
||||
message: str | None = Field(None)
|
||||
outputs: list[Kling3TurboOutput] | None = Field(None)
|
||||
|
||||
|
||||
class Kling3TurboQueryResponse(BaseModel):
|
||||
code: int | None = Field(None)
|
||||
message: str | None = Field(None)
|
||||
request_id: str | None = Field(None)
|
||||
data: list[Kling3TurboTaskData] | None = Field(None)
|
||||
|
||||
@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, confloat
|
||||
class LumaIO:
|
||||
LUMA_REF = "LUMA_REF"
|
||||
LUMA_CONCEPTS = "LUMA_CONCEPTS"
|
||||
LUMA_RAY32_KEYFRAME = "LUMA_RAY32_KEYFRAME"
|
||||
|
||||
|
||||
class LumaReference:
|
||||
@ -20,13 +21,14 @@ class LumaReference:
|
||||
def create_api_model(self, download_url: str):
|
||||
return LumaImageRef(url=download_url, weight=self.weight)
|
||||
|
||||
|
||||
class LumaReferenceChain:
|
||||
def __init__(self, first_ref: LumaReference=None):
|
||||
def __init__(self, first_ref: LumaReference = None):
|
||||
self.refs: list[LumaReference] = []
|
||||
if first_ref:
|
||||
self.refs.append(first_ref)
|
||||
|
||||
def add(self, luma_ref: LumaReference=None):
|
||||
def add(self, luma_ref: LumaReference = None):
|
||||
self.refs.append(luma_ref)
|
||||
|
||||
def create_api_model(self, download_urls: list[str], max_refs=4):
|
||||
@ -124,7 +126,7 @@ def get_luma_concepts(include_none=False):
|
||||
"pull_out",
|
||||
"aerial",
|
||||
"crane_up",
|
||||
"eye_level"
|
||||
"eye_level",
|
||||
]
|
||||
|
||||
|
||||
@ -162,8 +164,8 @@ class LumaVideoModelOutputDuration(str, Enum):
|
||||
|
||||
|
||||
class LumaGenerationType(str, Enum):
|
||||
video = 'video'
|
||||
image = 'image'
|
||||
video = "video"
|
||||
image = "image"
|
||||
|
||||
|
||||
class LumaState(str, Enum):
|
||||
@ -174,86 +176,109 @@ class LumaState(str, Enum):
|
||||
|
||||
|
||||
class LumaAssets(BaseModel):
|
||||
video: Optional[str] = Field(None, description='The URL of the video')
|
||||
image: Optional[str] = Field(None, description='The URL of the image')
|
||||
progress_video: Optional[str] = Field(None, description='The URL of the progress video')
|
||||
video: Optional[str] = Field(None, description="The URL of the video")
|
||||
image: Optional[str] = Field(None, description="The URL of the image")
|
||||
progress_video: Optional[str] = Field(None, description="The URL of the progress video")
|
||||
|
||||
|
||||
class LumaImageRef(BaseModel):
|
||||
"""Used for image gen"""
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
|
||||
url: str = Field(..., description="The URL of the image reference")
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
|
||||
|
||||
|
||||
class LumaImageReference(BaseModel):
|
||||
"""Used for video gen"""
|
||||
type: Optional[str] = Field('image', description='Input type, defaults to image')
|
||||
url: str = Field(..., description='The URL of the image')
|
||||
|
||||
type: Optional[str] = Field("image", description="Input type, defaults to image")
|
||||
url: str = Field(..., description="The URL of the image")
|
||||
|
||||
|
||||
class LumaModifyImageRef(BaseModel):
|
||||
url: str = Field(..., description='The URL of the image reference')
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
|
||||
url: str = Field(..., description="The URL of the image reference")
|
||||
weight: confloat(ge=0.0, le=1.0) = Field(..., description="The weight of the image reference")
|
||||
|
||||
|
||||
class LumaCharacterRef(BaseModel):
|
||||
identity0: LumaImageIdentity = Field(..., description='The image identity object')
|
||||
identity0: LumaImageIdentity = Field(..., description="The image identity object")
|
||||
|
||||
|
||||
class LumaImageIdentity(BaseModel):
|
||||
images: list[str] = Field(..., description='The URLs of the image identity')
|
||||
images: list[str] = Field(..., description="The URLs of the image identity")
|
||||
|
||||
|
||||
class LumaGenerationReference(BaseModel):
|
||||
type: str = Field('generation', description='Input type, defaults to generation')
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
type: str = Field("generation", description="Input type, defaults to generation")
|
||||
id: str = Field(..., description="The ID of the generation")
|
||||
|
||||
|
||||
class LumaKeyframes(BaseModel):
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='')
|
||||
frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
|
||||
frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description="")
|
||||
|
||||
|
||||
class LumaConceptObject(BaseModel):
|
||||
key: str = Field(..., description='Camera Concept name')
|
||||
key: str = Field(..., description="Camera Concept name")
|
||||
|
||||
|
||||
class LumaImageGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
|
||||
prompt: str = Field(..., description="The prompt of the generation")
|
||||
model: LumaImageModel = Field(LumaImageModel.photon_1, description="The image model used for the generation")
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9)
|
||||
image_ref: Optional[list[LumaImageRef]] = Field(None, description="List of image reference objects")
|
||||
style_ref: Optional[list[LumaImageRef]] = Field(None, description="List of style reference objects")
|
||||
character_ref: Optional[LumaCharacterRef] = Field(None, description="The image identity object")
|
||||
modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description="The modify image reference object")
|
||||
|
||||
|
||||
class LumaGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The prompt of the generation')
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
|
||||
loop: Optional[bool] = Field(None, description='Whether to loop the video')
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
|
||||
prompt: str = Field(..., description="The prompt of the generation")
|
||||
model: LumaVideoModel = Field(LumaVideoModel.ray_2, description="The video model used for the generation")
|
||||
duration: Optional[LumaVideoModelOutputDuration] = Field(None, description="The duration of the generation")
|
||||
aspect_ratio: Optional[LumaAspectRatio] = Field(None, description="The aspect ratio of the generation")
|
||||
resolution: Optional[LumaVideoOutputResolution] = Field(None, description="The resolution of the generation")
|
||||
loop: Optional[bool] = Field(None, description="Whether to loop the video")
|
||||
keyframes: Optional[LumaKeyframes] = Field(None, description="The keyframes of the generation")
|
||||
concepts: Optional[list[LumaConceptObject]] = Field(None, description="Camera Concepts to apply to generation")
|
||||
|
||||
|
||||
class LumaGeneration(BaseModel):
|
||||
id: str = Field(..., description='The ID of the generation')
|
||||
generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
|
||||
state: LumaState = Field(..., description='The state of the generation')
|
||||
failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
|
||||
created_at: str = Field(..., description='The date and time when the generation was created')
|
||||
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
|
||||
model: str = Field(..., description='The model used for the generation')
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
|
||||
id: str = Field(..., description="The ID of the generation")
|
||||
generation_type: LumaGenerationType = Field(..., description="Generation type, image or video")
|
||||
state: LumaState = Field(..., description="The state of the generation")
|
||||
failure_reason: Optional[str] = Field(None, description="The reason for the state of the generation")
|
||||
created_at: str = Field(..., description="The date and time when the generation was created")
|
||||
assets: Optional[LumaAssets] = Field(None, description="The assets of the generation")
|
||||
model: str = Field(..., description="The model used for the generation")
|
||||
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(...)
|
||||
|
||||
|
||||
class Luma2ImageRef(BaseModel):
|
||||
url: str | None = None
|
||||
data: str | None = None
|
||||
media_type: str | None = None
|
||||
generation_id: str | None = Field(None, description="reference a prior generation (extend / source reuse)")
|
||||
|
||||
|
||||
class Luma2VideoEdit(BaseModel):
|
||||
"""Edit controls for Ray 3.2 ``video_edit`` generations."""
|
||||
|
||||
auto_controls: bool | None = Field(None, description="derive a conditioning schedule from the source (recommended)")
|
||||
strength: str | None = Field(None, description="'adhere_1' .. 'reimagine_3'; constrained by IO.Combo")
|
||||
|
||||
|
||||
class Luma2VideoOptions(BaseModel):
|
||||
"""Ray 3.2 ``video`` output settings (text / image / keyframe / edit / extend)."""
|
||||
|
||||
resolution: str | None = Field(None, description="360p | 540p | 720p | 1080p")
|
||||
duration: str | None = Field(None, description="5s | 10s")
|
||||
loop: bool | None = Field(None)
|
||||
start_frame: Luma2ImageRef | None = Field(None)
|
||||
end_frame: Luma2ImageRef | None = Field(None)
|
||||
keyframes: list[Luma2ImageRef] | None = Field(None)
|
||||
keyframe_indexes: list[int] | None = Field(None)
|
||||
edit: Luma2VideoEdit | None = Field(None)
|
||||
|
||||
|
||||
class Luma2GenerationRequest(BaseModel):
|
||||
@ -266,6 +291,7 @@ class Luma2GenerationRequest(BaseModel):
|
||||
web_search: bool | None = None
|
||||
image_ref: list[Luma2ImageRef] | None = None
|
||||
source: Luma2ImageRef | None = None
|
||||
video: Luma2VideoOptions | None = Field(None)
|
||||
|
||||
|
||||
class Luma2Generation(BaseModel):
|
||||
@ -277,3 +303,31 @@ class Luma2Generation(BaseModel):
|
||||
output: list[LumaImageReference] | None = None
|
||||
failure_reason: str | None = None
|
||||
failure_code: str | None = None
|
||||
|
||||
|
||||
# --- Ray 3.2 multi-keyframe chain ---
|
||||
|
||||
LUMA_KEYFRAME_MODE_FRACTION = "fraction" # value in [0.0, 1.0] of the output video duration
|
||||
LUMA_KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the output
|
||||
|
||||
|
||||
class LumaRay32KeyframeItem:
|
||||
"""One guide image anchored at a position on the Ray 3.2 output timeline."""
|
||||
|
||||
def __init__(self, image: torch.Tensor, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # LUMA_KEYFRAME_MODE_FRACTION | LUMA_KEYFRAME_MODE_SECONDS
|
||||
self.value = value
|
||||
|
||||
|
||||
class LumaRay32KeyframeChain:
|
||||
def __init__(self):
|
||||
self.items: list[LumaRay32KeyframeItem] = []
|
||||
|
||||
def add(self, item: LumaRay32KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "LumaRay32KeyframeChain":
|
||||
c = LumaRay32KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: RunwayTaskStatusEnum
|
||||
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayAleph2IO:
|
||||
"""Custom socket types for chaining Aleph2 guidance images."""
|
||||
|
||||
KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
|
||||
PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||
|
||||
|
||||
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
|
||||
# choose the request model below. The values match the Aleph2 keyframe union field names.
|
||||
KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video
|
||||
KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration
|
||||
|
||||
# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video
|
||||
PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeItem:
|
||||
"""A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeChain:
|
||||
"""An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2KeyframeItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2KeyframeChain":
|
||||
c = RunwayAleph2KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageItem:
|
||||
"""A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageChain:
|
||||
"""An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2PromptImageItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2PromptImageItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2PromptImageChain":
|
||||
c = RunwayAleph2PromptImageChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeSeconds(BaseModel):
|
||||
seconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
|
||||
ge=0.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeAt(BaseModel):
|
||||
at: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the input video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2TimestampPosition(BaseModel):
|
||||
type: str = Field(default="timestamp")
|
||||
timestampSeconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the output video.",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2RelativePosition(BaseModel):
|
||||
type: str = Field(default="position")
|
||||
positionPercentage: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the total output video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImage(BaseModel):
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2ContentModeration(BaseModel):
|
||||
publicFigureThreshold: str = Field(
|
||||
...,
|
||||
description='When set to "low", the content moderation system is less strict about '
|
||||
'recognizable public figures. One of "auto" or "low".',
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Request(BaseModel):
|
||||
model: str = Field(default="aleph2")
|
||||
promptText: str = Field(
|
||||
...,
|
||||
description="A non-empty string describing what should appear in the output.",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
videoUri: str = Field(...)
|
||||
seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
|
||||
contentModeration: RunwayAleph2ContentModeration = Field(...)
|
||||
keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
|
||||
None,
|
||||
description="Timed guidance images placed at specific points in the input video. Up to 5.",
|
||||
)
|
||||
promptImage: list[RunwayAleph2PromptImage] | None = Field(
|
||||
None,
|
||||
description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Response(BaseModel):
|
||||
id: str | None = Field(None, description="Task ID")
|
||||
|
||||
@ -208,6 +208,10 @@ class TripoMultiviewToModelRequest(BaseModel):
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoTexturePrompt(BaseModel):
|
||||
text: str | None = Field(None, description="Text guidance for texture generation")
|
||||
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
@ -219,6 +223,11 @@ class TripoTextureModelRequest(BaseModel):
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
texture_prompt: TripoTexturePrompt | None = Field(
|
||||
None,
|
||||
description="Optional guidance for texturing. Required in practice for imported models, "
|
||||
"which carry no source image to infer texture from.",
|
||||
)
|
||||
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
@ -307,6 +316,17 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
|
||||
orientation: str | None = None
|
||||
|
||||
|
||||
class TripoImportModelRequest(BaseModel):
|
||||
"""Request for the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).
|
||||
|
||||
The model file is uploaded to ComfyUI API storage first; the backend downloads it from
|
||||
`url`, re-uploads it to Tripo's storage and creates the import_model task server-side.
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="ComfyUI API storage download URL of the model file")
|
||||
format: str = Field(..., description='File format: "glb", "fbx", "obj" or "stl"')
|
||||
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: str | None = Field(None, description="URL to the model")
|
||||
base_model: str | None = Field(None, description="URL to the base model")
|
||||
|
||||
@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
max=25,
|
||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
dilate_pixels: int = 10,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
dilate_pixels=dilate_pixels,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ from comfy_api_nodes.apis.bria import (
|
||||
BriaRemoveVideoBackgroundRequest,
|
||||
BriaRemoveVideoBackgroundResponse,
|
||||
BriaStatusResponse,
|
||||
BriaVideoGreenScreenRequest,
|
||||
BriaVideoReplaceBackgroundRequest,
|
||||
InputModerationSettings,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -287,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -319,6 +321,161 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoGreenScreen(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoGreenScreen",
|
||||
display_name="Bria Video Green Screen",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a solid chroma-key screen using Bria.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input(
|
||||
"green_shade",
|
||||
options=["broadcast_green", "chroma_green", "blue_screen"],
|
||||
tooltip="Solid chroma-key shade applied behind the foreground: "
|
||||
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
green_shade: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
|
||||
data=BriaVideoGreenScreenRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
green_shade=green_shade,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoReplaceBackground",
|
||||
display_name="Bria Video Replace Background",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a supplied image or video using Bria. "
|
||||
"The output keeps the foreground's resolution and frame rate; a background with a "
|
||||
"different aspect ratio is stretched to fit, so match it for undistorted results.",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
|
||||
IO.Image.Input(
|
||||
"background_image",
|
||||
optional=True,
|
||||
tooltip="Background image to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"background_video",
|
||||
optional=True,
|
||||
tooltip="Background video to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
background_image: Input.Image | None = None,
|
||||
background_video: Input.Video | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if (background_image is None) == (background_video is None):
|
||||
raise ValueError("Provide either a background image or a background video, not both.")
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
if background_video is not None:
|
||||
validate_video_duration(background_video, max_duration=60.0)
|
||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||
else:
|
||||
# Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
|
||||
background_url = await upload_image_to_comfyapi(
|
||||
cls, background_image[:, :, :, :3], wait_label="Uploading background"
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||
data=BriaVideoReplaceBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_url=background_url,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
|
||||
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
|
||||
|
||||
@ -376,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -416,6 +573,8 @@ class BriaExtension(ComfyExtension):
|
||||
BriaImageEditNode,
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
BriaVideoGreenScreen,
|
||||
BriaVideoReplaceBackground,
|
||||
BriaTransparentVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from io import BytesIO
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image:
|
||||
return image
|
||||
|
||||
|
||||
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
|
||||
SEEDANCE2_RATIO_WH = {
|
||||
"16:9": (16, 9),
|
||||
"4:3": (4, 3),
|
||||
"1:1": (1, 1),
|
||||
"3:4": (3, 4),
|
||||
"9:16": (9, 16),
|
||||
"21:9": (21, 9),
|
||||
}
|
||||
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
|
||||
|
||||
|
||||
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
|
||||
"""Exact supported output (width, height) for (resolution, ratio).
|
||||
|
||||
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
|
||||
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
|
||||
to the nearest supported ratio, so the output keeps the frame's orientation.
|
||||
"""
|
||||
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
|
||||
if ratio not in SEEDANCE2_RATIO_WH:
|
||||
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
|
||||
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
|
||||
rw, rh = SEEDANCE2_RATIO_WH[ratio]
|
||||
if rw >= rh: # landscape or square: shorter side is the height
|
||||
out_w, out_h = round(short * rw / rh), short
|
||||
else: # portrait: shorter side is the width
|
||||
out_w, out_h = short, round(short * rh / rw)
|
||||
return out_w - out_w % 2, out_h - out_h % 2
|
||||
|
||||
|
||||
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
|
||||
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
|
||||
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
|
||||
resized = common_upscale(samples, width, height, "lanczos", "center")
|
||||
return resized.movedim(1, -1)
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
request_ratio = model["ratio"]
|
||||
if first_frame_asset_id or last_frame_asset_id:
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
else:
|
||||
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
|
||||
# only applies to local image inputs we can resize.
|
||||
request_ratio = "adaptive"
|
||||
target_dims: tuple[int, int] | None = None
|
||||
if first_frame is not None:
|
||||
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
|
||||
first_frame = _resize_to_exact(first_frame, *target_dims)
|
||||
if last_frame is not None:
|
||||
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(last_frame, min_width=300, min_height=300)
|
||||
if target_dims is None:
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
|
||||
last_frame = _resize_to_exact(last_frame, *target_dims)
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
ratio=request_ratio,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
|
||||
@ -5,10 +5,9 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
|
||||
import base64
|
||||
import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@ -19,6 +18,7 @@ from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerationConfig,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
@ -40,13 +40,18 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_URL_INPUT_BUDGET = 10
|
||||
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
@ -72,15 +77,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
||||
)
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
"""
|
||||
Gemini Image Model Names allowed by comfy-api
|
||||
"""
|
||||
|
||||
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
||||
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image | list[Input.Image],
|
||||
@ -237,21 +233,15 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
|
||||
if response.modelVersion == "gemini-2.5-pro":
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-flash",
|
||||
):
|
||||
elif response.modelVersion == "gemini-2.5-flash":
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-2.5-flash-image",
|
||||
):
|
||||
elif response.modelVersion == "gemini-2.5-flash-image":
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 30.0
|
||||
@ -285,6 +275,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
return final_price / 1_000_000.0
|
||||
|
||||
|
||||
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
|
||||
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
|
||||
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
|
||||
frames: list[torch.Tensor] = []
|
||||
for img in images:
|
||||
if len(img.shape) == 4:
|
||||
frames.extend(img[i] for i in range(img.shape[0]))
|
||||
else:
|
||||
frames.append(img)
|
||||
return frames
|
||||
|
||||
|
||||
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
|
||||
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
|
||||
clips: list[Input.Audio] = []
|
||||
for audio in audios:
|
||||
waveform = audio["waveform"]
|
||||
for i in range(waveform.shape[0]):
|
||||
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
|
||||
return clips
|
||||
|
||||
|
||||
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
|
||||
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
|
||||
if kind == "image":
|
||||
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
|
||||
if kind == "audio":
|
||||
url = await upload_audio_to_comfyapi(
|
||||
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
|
||||
)
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
|
||||
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
|
||||
|
||||
|
||||
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
|
||||
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
|
||||
if kind == "image":
|
||||
data = tensor_to_base64_string(payload, mime_type="image/webp")
|
||||
mime = GeminiMimeType.image_webp
|
||||
elif kind == "audio":
|
||||
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
|
||||
mime = GeminiMimeType.audio_mp3
|
||||
else:
|
||||
data = video_to_base64_string(
|
||||
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
mime = GeminiMimeType.video_mp4
|
||||
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
|
||||
|
||||
|
||||
async def build_gemini_media_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: list[Input.Image],
|
||||
audios: list[Input.Audio],
|
||||
videos: list[Input.Video],
|
||||
*,
|
||||
url_budget: int = GEMINI_URL_INPUT_BUDGET,
|
||||
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
|
||||
) -> list[GeminiPart]:
|
||||
"""Build Gemini parts for multimodal inputs (images, audio, video).
|
||||
|
||||
fileData URLs are preferred for every media type: the upload is fetched directly by the
|
||||
model, keeping the request body tiny regardless of media size. The URL budget is shared
|
||||
across all media and assigned largest-first (video, then audio, then images), so that if it
|
||||
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
|
||||
payload is capped by `max_inline_bytes`.
|
||||
"""
|
||||
units: list[tuple[str, Any]] = (
|
||||
[("video", v) for v in videos]
|
||||
+ [("audio", a) for a in _flatten_audio(audios)]
|
||||
+ [("image", f) for f in _flatten_images(images)]
|
||||
)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
url_used = 0
|
||||
inline_bytes = 0
|
||||
for kind, payload in units:
|
||||
if url_used < url_budget:
|
||||
parts.append(await _media_url_part(cls, kind, payload))
|
||||
url_used += 1
|
||||
continue
|
||||
part, nbytes = _media_inline_part(kind, payload)
|
||||
inline_bytes += nbytes
|
||||
if inline_bytes > max_inline_bytes:
|
||||
raise ValueError(
|
||||
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
|
||||
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
|
||||
)
|
||||
parts.append(part)
|
||||
return parts
|
||||
|
||||
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@ -315,8 +439,6 @@ class GeminiNode(IO.ComfyNode):
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-3-pro-preview",
|
||||
@ -407,58 +529,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded audio.
|
||||
"""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
@ -482,9 +555,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
if images is not None:
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
parts.extend(create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(cls.create_video_parts(video))
|
||||
parts.extend(create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
@ -512,6 +585,210 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
GEMINI_V2_MODELS: dict[str, str] = {
|
||||
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
|
||||
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
|
||||
}
|
||||
|
||||
|
||||
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
|
||||
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
|
||||
return [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 17)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"audio",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Audio.Input("audio"),
|
||||
names=["audio_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional audio clip to use as context for the model.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"video",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=["video_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional video clip to use as context for the model.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Input Files node.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["LOW", "HIGH"],
|
||||
default=thinking_default,
|
||||
tooltip="How hard the model reasons internally before answering. "
|
||||
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"max_output_tokens",
|
||||
default=32768,
|
||||
min=16,
|
||||
max=65536,
|
||||
tooltip="Maximum tokens to generate, including the model's internal thinking. "
|
||||
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
|
||||
"responses come back empty or truncated. The model stops early when finished, so a higher "
|
||||
"cap costs nothing extra for short replies.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNodeV2",
|
||||
display_name="Google Gemini",
|
||||
category="partner/text/Gemini",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
|
||||
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
|
||||
],
|
||||
tooltip="The Gemini model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "lite") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.0015],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
} : {
|
||||
"type": "list_usd",
|
||||
"usd": [0.002, 0.012],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = GEMINI_V2_MODELS[model["model"]]
|
||||
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
|
||||
videos = [v for v in (model.get("video") or {}).values() if v is not None]
|
||||
if images or audios or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
temperature=model["temperature"],
|
||||
topP=model["top_p"],
|
||||
maxOutputTokens=model["max_output_tokens"],
|
||||
seed=seed if seed > 0 else None,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
class GeminiInputFiles(IO.ComfyNode):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
@ -609,8 +886,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=GeminiImageModel,
|
||||
default=GeminiImageModel.gemini_2_5_flash_image,
|
||||
options=["gemini-2.5-flash-image"],
|
||||
tooltip="The Gemini model to use for generating responses.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
@ -1129,6 +1405,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -1165,6 +1461,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 0.95,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
@ -1204,6 +1502,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
temperature=temperature,
|
||||
topP=top_p,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
@ -1222,6 +1522,7 @@ class GeminiExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiNodeV2,
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
|
||||
@ -30,7 +30,7 @@ from comfy_api_nodes.util import (
|
||||
|
||||
|
||||
_GROK_VIDEO_MODEL_API_IDS = {
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5",
|
||||
}
|
||||
|
||||
|
||||
@ -521,8 +521,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["480p", "720p"],
|
||||
tooltip="The resolution of the output video.",
|
||||
options=["480p", "720p", "1080p"],
|
||||
tooltip="The resolution of the output video. 1080p is only available for grok-imagine-video-1.5.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -570,11 +570,12 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
(
|
||||
$is15 := $contains(widgets.model, "1.5");
|
||||
$rate := $is15
|
||||
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
|
||||
? (widgets.resolution = "1080p" ? 0.25 : (widgets.resolution = "720p" ? 0.14 : 0.08))
|
||||
: (widgets.resolution = "720p" ? 0.07 : 0.05);
|
||||
$imgCost := $is15 ? 0.0143 : 0.002;
|
||||
$imgCost := $is15 ? 0.01 : 0.002;
|
||||
$base := $rate * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
|
||||
$total := inputs.image.connected ? $base + $imgCost : $base;
|
||||
{"type":"usd","usd": $is15 ? $total * 1.43 : $total}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -593,6 +594,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
) -> IO.NodeOutput:
|
||||
if image is None and model == "grok-imagine-video-1.5":
|
||||
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
|
||||
if resolution == "1080p" and model != "grok-imagine-video-1.5":
|
||||
raise ValueError(f"1080p resolution is only available for grok-imagine-video-1.5, not '{model}'.")
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
|
||||
@ -60,6 +60,12 @@ from comfy_api_nodes.apis.kling import (
|
||||
OmniProImageRequest,
|
||||
OmniProReferences2VideoRequest,
|
||||
OmniProText2VideoRequest,
|
||||
Kling3TurboSettings,
|
||||
Kling3TurboText2VideoRequest,
|
||||
Kling3TurboContent,
|
||||
Kling3TurboImage2VideoRequest,
|
||||
Kling3TurboCreateResponse,
|
||||
Kling3TurboQueryResponse,
|
||||
TaskStatusResponse,
|
||||
TextToVideoWithAudioRequest,
|
||||
)
|
||||
@ -436,7 +442,7 @@ async def execute_text2video(
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
mode=KlingVideoGenMode(model_mode),
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
model_name=model_name,
|
||||
cfg_scale=cfg_scale,
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
@ -2847,6 +2853,67 @@ class MotionControl(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
def build_turbo_shot_prompt(multi_prompt: list[MultiPromptEntry]) -> str:
|
||||
"""Render storyboard entries into the Turbo multi-shot prompt 'shot n, m, words; ...'."""
|
||||
return "; ".join(f"shot {i}, {int(e.duration)}, {e.prompt}" for i, e in enumerate(multi_prompt, 1)) + ";"
|
||||
|
||||
|
||||
def _turbo_video_url(response: Kling3TurboQueryResponse) -> str:
|
||||
"""Extract the result video URL from a /tasks response (data[].outputs[] where type == 'video')."""
|
||||
task = response.data[0] if response.data else None
|
||||
if task and task.outputs:
|
||||
for output in task.outputs:
|
||||
if output.type == "video" and output.url:
|
||||
return output.url
|
||||
raise RuntimeError(f"Kling 3.0 Turbo task finished without a video output: {response.model_dump()}")
|
||||
|
||||
|
||||
async def execute_kling_turbo(
|
||||
cls: type[IO.ComfyNode],
|
||||
*,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
start_frame: torch.Tensor | None,
|
||||
) -> IO.NodeOutput:
|
||||
"""Create + poll a Kling 3.0 Turbo task. Image-to-video when start_frame is given, else text-to-video."""
|
||||
if start_frame is not None:
|
||||
validate_image_dimensions(start_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
|
||||
contents = [Kling3TurboContent(type="first_frame", url=tensor_to_base64_string(start_frame))]
|
||||
if prompt:
|
||||
contents.insert(0, Kling3TurboContent(type="prompt", text=prompt))
|
||||
create = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/image-to-video/kling-3.0-turbo", method="POST"),
|
||||
response_model=Kling3TurboCreateResponse,
|
||||
data=Kling3TurboImage2VideoRequest(
|
||||
contents=contents,
|
||||
settings=Kling3TurboSettings(resolution=resolution, duration=duration), # i2v: no aspect_ratio
|
||||
),
|
||||
)
|
||||
else:
|
||||
create = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/text-to-video/kling-3.0-turbo", method="POST"),
|
||||
response_model=Kling3TurboCreateResponse,
|
||||
data=Kling3TurboText2VideoRequest(
|
||||
prompt=prompt,
|
||||
settings=Kling3TurboSettings(resolution=resolution, aspect_ratio=aspect_ratio, duration=duration),
|
||||
),
|
||||
)
|
||||
if not (create.data and create.data.id):
|
||||
raise RuntimeError(f"Kling 3.0 Turbo create failed. Code: {create.code}, Message: {create.message}")
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/tasks", query_params={"task_ids": create.data.id}),
|
||||
response_model=Kling3TurboQueryResponse,
|
||||
status_extractor=lambda r: (r.data[0].status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(_turbo_video_url(final_response)))
|
||||
|
||||
|
||||
class KlingVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -2884,7 +2951,11 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
],
|
||||
tooltip="Generate a series of video segments with individual prompts and durations.",
|
||||
),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=True,
|
||||
tooltip="'kling-3.0-turbo' always generates native audio, so the audio toggle is ignored.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
@ -2899,6 +2970,17 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"kling-3.0-turbo",
|
||||
[
|
||||
IO.Combo.Input("resolution", options=["1080p", "720p"], default="720p"),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "1:1"],
|
||||
tooltip="Ignored in image-to-video mode.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Model and generation settings.",
|
||||
),
|
||||
@ -2930,6 +3012,7 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=[
|
||||
"model",
|
||||
"model.resolution",
|
||||
"generate_audio",
|
||||
"multi_shot",
|
||||
@ -2944,14 +3027,7 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$rates := {
|
||||
"4k": {"off": 0.42, "on": 0.42},
|
||||
"1080p": {"off": 0.112, "on": 0.168},
|
||||
"720p": {"off": 0.084, "on": 0.126}
|
||||
};
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$audio := widgets.generate_audio ? "on" : "off";
|
||||
$rate := $lookup($lookup($rates, $res), $audio);
|
||||
$ms := widgets.multi_shot;
|
||||
$isSb := $ms != "disabled";
|
||||
$n := $isSb ? $number($substring($ms, 0, 1)) : 0;
|
||||
@ -2962,7 +3038,18 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
$d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
|
||||
$d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
|
||||
$dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
|
||||
{"type":"usd","usd": $rate * $dur}
|
||||
widgets.model = "kling-3.0-turbo"
|
||||
? {"type":"usd","usd": ($res = "1080p" ? 0.14 : 0.112) * $dur}
|
||||
: (
|
||||
$rates := {
|
||||
"4k": {"off": 0.42, "on": 0.42},
|
||||
"1080p": {"off": 0.112, "on": 0.168},
|
||||
"720p": {"off": 0.084, "on": 0.126}
|
||||
};
|
||||
$audio := widgets.generate_audio ? "on" : "off";
|
||||
$rate := $lookup($lookup($rates, $res), $audio);
|
||||
{"type":"usd","usd": $rate * $dur}
|
||||
)
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -3015,6 +3102,17 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
duration = multi_shot["duration"]
|
||||
validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
|
||||
|
||||
if model["model"] == "kling-3.0-turbo":
|
||||
turbo_prompt = build_turbo_shot_prompt(multi_prompt_list) if custom_multi_shot else multi_shot["prompt"]
|
||||
return await execute_kling_turbo(
|
||||
cls,
|
||||
prompt=turbo_prompt,
|
||||
resolution=model["resolution"],
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
duration=duration,
|
||||
start_frame=start_frame,
|
||||
)
|
||||
|
||||
if start_frame is not None:
|
||||
validate_image_dimensions(start_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
|
||||
|
||||
@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima
|
||||
|
||||
|
||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
|
||||
_MODEL_LARGE = "Krea 2 Large"
|
||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
|
||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||
}
|
||||
|
||||
@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F
|
||||
|
||||
|
||||
def _krea_model_inputs() -> list:
|
||||
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
|
||||
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode):
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||
],
|
||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||
@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode):
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isLarge := widgets.model = "krea 2 large";
|
||||
$rates := {
|
||||
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
|
||||
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
|
||||
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
|
||||
};
|
||||
$r := $lookup($rates, widgets.model);
|
||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||
$usd := $hasMoodboard
|
||||
? ($isLarge ? 0.07 : 0.04)
|
||||
: ($hasStyle
|
||||
? ($isLarge ? 0.065 : 0.035)
|
||||
: ($isLarge ? 0.06 : 0.03));
|
||||
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
|
||||
{"type":"usd","usd": $usd}
|
||||
)
|
||||
""",
|
||||
|
||||
@ -3,9 +3,13 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.luma import (
|
||||
LUMA_KEYFRAME_MODE_FRACTION,
|
||||
LUMA_KEYFRAME_MODE_SECONDS,
|
||||
Luma2Generation,
|
||||
Luma2GenerationRequest,
|
||||
Luma2ImageRef,
|
||||
Luma2VideoEdit,
|
||||
Luma2VideoOptions,
|
||||
LumaAspectRatio,
|
||||
LumaCharacterRef,
|
||||
LumaConceptChain,
|
||||
@ -18,6 +22,8 @@ from comfy_api_nodes.apis.luma import (
|
||||
LumaIO,
|
||||
LumaKeyframes,
|
||||
LumaModifyImageRef,
|
||||
LumaRay32KeyframeChain,
|
||||
LumaRay32KeyframeItem,
|
||||
LumaReference,
|
||||
LumaReferenceChain,
|
||||
LumaVideoModel,
|
||||
@ -33,6 +39,7 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
@ -692,7 +699,10 @@ async def _luma2_upload_image_refs(
|
||||
async def _luma2_submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: Luma2GenerationRequest,
|
||||
) -> Input.Image:
|
||||
*,
|
||||
estimated_duration: int | None = None,
|
||||
) -> Luma2Generation:
|
||||
"""Submit a Luma Agents generation and poll until done; returns the completed generation."""
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
|
||||
@ -700,21 +710,21 @@ async def _luma2_submit_and_poll(
|
||||
data=request,
|
||||
)
|
||||
if not initial.id:
|
||||
raise RuntimeError("Luma 2 API did not return a generation id.")
|
||||
raise RuntimeError("Luma API did not return a generation id.")
|
||||
final = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
|
||||
response_model=Luma2Generation,
|
||||
status_extractor=lambda r: r.state,
|
||||
progress_extractor=lambda r: None,
|
||||
estimated_duration=estimated_duration,
|
||||
)
|
||||
if not final.output:
|
||||
if not final.output or not final.output[0].url:
|
||||
msg = final.failure_reason or "no output returned"
|
||||
raise RuntimeError(f"Luma 2 generation failed: {msg}")
|
||||
url = final.output[0].url
|
||||
if not url:
|
||||
raise RuntimeError("Luma 2 generation completed without an output URL.")
|
||||
return await download_url_to_image_tensor(url)
|
||||
if final.failure_code:
|
||||
msg = f"{msg} [{final.failure_code}]"
|
||||
raise RuntimeError(f"Luma generation failed: {msg}")
|
||||
return final
|
||||
|
||||
|
||||
class LumaImageNode(IO.ComfyNode):
|
||||
@ -843,7 +853,8 @@ class LumaImageNode(IO.ComfyNode):
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
final = await _luma2_submit_and_poll(cls, request)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
|
||||
|
||||
|
||||
class LumaImageEditNode(IO.ComfyNode):
|
||||
@ -929,7 +940,533 @@ class LumaImageEditNode(IO.ComfyNode):
|
||||
web_search=model["web_search"],
|
||||
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
|
||||
)
|
||||
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
|
||||
final = await _luma2_submit_and_poll(cls, request)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final.output[0].url))
|
||||
|
||||
|
||||
_BADGE_RAY32_VIDEO = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {
|
||||
"360p": {"5s": 0.06, "10s": 0.18},
|
||||
"540p": {"5s": 0.15, "10s": 0.45},
|
||||
"720p": {"5s": 0.3, "10s": 0.9},
|
||||
"1080p": {"5s": 1.2, "10s": 3.6}
|
||||
};
|
||||
{"type": "usd", "usd": $lookup($lookup($p, widgets.resolution), widgets.duration)}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_VIDEO_5S = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {"360p": 0.06, "540p": 0.15, "720p": 0.3, "1080p": 1.2};
|
||||
{"type": "usd", "usd": $lookup($p, widgets.resolution)}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_EDIT = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {
|
||||
"360p": {"min": 0.54, "max": 1.08},
|
||||
"540p": {"min": 0.72, "max": 1.44},
|
||||
"720p": {"min": 1.08, "max": 2.16},
|
||||
"1080p": {"min": 2.16, "max": 4.32}
|
||||
};
|
||||
$r := $lookup($p, widgets.resolution);
|
||||
{"type": "range_usd", "min_usd": $r.min, "max_usd": $r.max, "format": {"note": "(by source length)"}}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
_BADGE_RAY32_REFRAME = IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$p := {"360p": 0.03, "540p": 0.06, "720p": 0.12, "1080p": 0.36};
|
||||
{"type": "usd", "usd": $lookup($p, widgets.resolution), "format": {"suffix": "/second"}}
|
||||
)
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
def _ray32_seed_input() -> IO.Input:
|
||||
return IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed to determine if node should re-run; results are nondeterministic regardless of seed.",
|
||||
)
|
||||
|
||||
|
||||
async def _ray32_generate(cls: type[IO.ComfyNode], request: Luma2GenerationRequest) -> IO.NodeOutput:
|
||||
"""Run a ray-3.2 generation and return (video, generation_id)."""
|
||||
final = await _luma2_submit_and_poll(cls, request, estimated_duration=120)
|
||||
video = await download_url_to_video_output(final.output[0].url)
|
||||
return IO.NodeOutput(video, final.id or "")
|
||||
|
||||
|
||||
class LumaRay32TextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32TextToVideoNode",
|
||||
display_name="Luma Ray 3.2 Text to Video",
|
||||
category="partner/video/Luma",
|
||||
description="Generate a video from a text prompt using Luma's Ray 3.2 model.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
|
||||
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
|
||||
IO.Combo.Input("duration", options=["5s", "10s"]),
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
tooltip="Make the video loop seamlessly. Only available with 5s duration.",
|
||||
),
|
||||
_ray32_seed_input(),
|
||||
],
|
||||
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, prompt: str, aspect_ratio: str, resolution: str, duration: str, loop: bool, seed: int
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
|
||||
if loop and duration == "10s":
|
||||
raise ValueError("Looping is only available with 5s duration on Ray 3.2.")
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model="ray-3.2",
|
||||
type="video",
|
||||
aspect_ratio=aspect_ratio,
|
||||
video=Luma2VideoOptions(resolution=resolution, duration=duration, loop=loop or None),
|
||||
)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32ImageToVideoNode",
|
||||
display_name="Luma Ray 3.2 Image to Video",
|
||||
category="partner/video/Luma",
|
||||
description="Generate a video from a start and/or end frame using Luma's Ray 3.2 model. "
|
||||
"Image-anchored generations are always 5 seconds.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
|
||||
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
tooltip="Make the video loop seamlessly. Not available when an end_frame is set.",
|
||||
),
|
||||
_ray32_seed_input(),
|
||||
IO.Image.Input("start_frame", optional=True, tooltip="First frame of the generated video."),
|
||||
IO.Image.Input("end_frame", optional=True, tooltip="Last frame of the generated video."),
|
||||
],
|
||||
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_VIDEO_5S,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
loop: bool,
|
||||
seed: int,
|
||||
start_frame: torch.Tensor | None = None,
|
||||
end_frame: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
|
||||
if start_frame is None and end_frame is None:
|
||||
raise ValueError("Provide at least one of start_frame / end_frame.")
|
||||
if loop and end_frame is not None:
|
||||
raise ValueError("Looping is not available when an end_frame is set.")
|
||||
video = Luma2VideoOptions(resolution=resolution, duration="5s", loop=loop or None)
|
||||
if start_frame is not None:
|
||||
url = await upload_image_to_comfyapi(cls, start_frame, mime_type="image/png")
|
||||
video.start_frame = Luma2ImageRef(url=url)
|
||||
if end_frame is not None:
|
||||
url = await upload_image_to_comfyapi(cls, end_frame, mime_type="image/png")
|
||||
video.end_frame = Luma2ImageRef(url=url)
|
||||
request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32KeyframeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32KeyframeNode",
|
||||
display_name="Luma Ray 3.2 Keyframe",
|
||||
category="partner/video/Luma",
|
||||
description="Anchor a guide image to a position on the Ray 3.2 output video timeline. Connect this to "
|
||||
"the 'keyframes' input of the Luma Ray 3.2 Keyframes to Video node; chain several together via the "
|
||||
"optional 'keyframes' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Guide image to place at the chosen moment of the output video."),
|
||||
IO.DynamicCombo.Input(
|
||||
"position",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Fraction of duration (0.0-1.0)",
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the output video this image applies " "(0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"Absolute time (seconds)",
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=10.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from the start of the output video where this "
|
||||
"image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the output video's timeline.",
|
||||
),
|
||||
IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Optional earlier keyframes to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Output(display_name="keyframes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
position: dict,
|
||||
keyframes: LumaRay32KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = keyframes.clone() if keyframes is not None else LumaRay32KeyframeChain()
|
||||
if position["position"] == "Absolute time (seconds)":
|
||||
mode, value = LUMA_KEYFRAME_MODE_SECONDS, float(position["seconds"])
|
||||
else:
|
||||
mode, value = LUMA_KEYFRAME_MODE_FRACTION, float(position["fraction"])
|
||||
chain.add(LumaRay32KeyframeItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class LumaRay32KeyframesToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32KeyframesToVideoNode",
|
||||
display_name="Luma Ray 3.2 Keyframes to Video",
|
||||
category="partner/video/Luma",
|
||||
description="Generate a video that interpolates through a sequence of guide images, each anchored to a "
|
||||
"position on the timeline, using Luma Ray 3.2. Build the sequence with Luma Ray 3.2 Keyframe nodes "
|
||||
"(at least 2).",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the video generation."),
|
||||
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
|
||||
IO.Combo.Input("duration", options=["5s", "10s"]),
|
||||
_ray32_seed_input(),
|
||||
IO.Custom(LumaIO.LUMA_RAY32_KEYFRAME).Input(
|
||||
"keyframes",
|
||||
tooltip="Keyframe sequence from Luma Ray 3.2 Keyframe nodes (at least 2).",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output(), IO.String.Output(display_name="generation_id")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_VIDEO,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
duration: str,
|
||||
seed: int,
|
||||
keyframes: LumaRay32KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
|
||||
items = keyframes.items if keyframes is not None else []
|
||||
if len(items) < 2:
|
||||
raise ValueError(
|
||||
"Connect at least 2 Luma Ray 3.2 Keyframe nodes "
|
||||
"(use Luma Ray 3.2 Image to Video for a single start/end frame)."
|
||||
)
|
||||
if len(items) > 64:
|
||||
raise ValueError(f"Ray 3.2 supports at most 64 keyframes; got {len(items)}.")
|
||||
maxframe = 120 if duration == "5s" else 240
|
||||
duration_seconds = maxframe / 24 # 5.0 or 10.0
|
||||
# Resolve each keyframe to an output-frame index, then order by position
|
||||
# (so the user can chain keyframes in any order — the position is what places them)
|
||||
placed: list[tuple[int, torch.Tensor]] = []
|
||||
for item in items:
|
||||
if item.mode == LUMA_KEYFRAME_MODE_SECONDS:
|
||||
if item.value > duration_seconds:
|
||||
raise ValueError(
|
||||
f"Keyframe position {item.value:g}s is past the end of the {duration} video; "
|
||||
f"use 0-{duration_seconds:g}s (or switch the keyframe to fraction mode)."
|
||||
)
|
||||
idx = round(item.value * 24)
|
||||
else:
|
||||
idx = round(item.value * maxframe)
|
||||
placed.append((max(0, min(maxframe, idx)), item.image))
|
||||
placed.sort(key=lambda p: p[0])
|
||||
indexes = [idx for idx, _ in placed]
|
||||
for a, b in zip(indexes, indexes[1:]):
|
||||
if a == b:
|
||||
raise ValueError(
|
||||
f"Two keyframes resolve to the same output frame ({a}) for a {duration} video "
|
||||
f"(valid range 0-{maxframe}); give each keyframe a distinct position."
|
||||
)
|
||||
refs: list[Luma2ImageRef] = []
|
||||
for _, image in placed:
|
||||
url = await upload_image_to_comfyapi(cls, image, mime_type="image/png")
|
||||
refs.append(Luma2ImageRef(url=url))
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model="ray-3.2",
|
||||
type="video",
|
||||
video=Luma2VideoOptions(resolution=resolution, duration=duration, keyframes=refs, keyframe_indexes=indexes),
|
||||
)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32VideoEditNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32VideoEditNode",
|
||||
display_name="Luma Ray 3.2 Video Edit",
|
||||
category="partner/video/Luma",
|
||||
description="Re-render an existing video under a new prompt using Luma Ray 3.2 (restyle, relight, add "
|
||||
"or remove elements) while keeping the original motion. Source video up to 18 seconds; the edited "
|
||||
"video keeps the source's length.",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Source video to edit. Up to 18 seconds."),
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Describes the desired edit."),
|
||||
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
|
||||
IO.Combo.Input(
|
||||
"strength",
|
||||
options=[
|
||||
"auto",
|
||||
"adhere_1",
|
||||
"adhere_2",
|
||||
"adhere_3",
|
||||
"flex_1",
|
||||
"flex_2",
|
||||
"flex_3",
|
||||
"reimagine_1",
|
||||
"reimagine_2",
|
||||
"reimagine_3",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="How strongly to preserve vs. reimagine the source. 'auto' lets Ray 3.2 choose; "
|
||||
"adhere_* preserves the most, flex_* is balanced, reimagine_* changes the most.",
|
||||
),
|
||||
_ray32_seed_input(),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(display_name="generation_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_EDIT,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, video: Input.Video, prompt: str, resolution: str, strength: str, seed: int
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=6000)
|
||||
try:
|
||||
duration = "5s" if video.get_duration() <= 5.0 else "10s"
|
||||
except Exception:
|
||||
duration = "10s"
|
||||
source_url = await upload_video_to_comfyapi(cls, video, max_duration=18)
|
||||
edit = Luma2VideoEdit(auto_controls=True) if strength == "auto" else Luma2VideoEdit(strength=strength)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model="ray-3.2",
|
||||
type="video_edit",
|
||||
source=Luma2ImageRef(url=source_url, media_type="video/mp4"),
|
||||
video=Luma2VideoOptions(resolution=resolution, duration=duration, edit=edit),
|
||||
)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32VideoReframeNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32VideoReframeNode",
|
||||
display_name="Luma Ray 3.2 Video Reframe",
|
||||
category="partner/video/Luma",
|
||||
description="Change the aspect ratio of an existing video, using Luma Ray 3.2 to fill the newly "
|
||||
"exposed canvas areas. Source video up to 30 seconds. Billed per second of output.",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Source video to reframe. Up to 30 seconds."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes how the newly exposed canvas areas should be filled.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "21:9"]),
|
||||
IO.Combo.Input("resolution", options=["360p", "540p", "720p", "1080p"], default="720p"),
|
||||
_ray32_seed_input(),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(display_name="generation_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_REFRAME,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, video: Input.Video, prompt: str, aspect_ratio: str, resolution: str, seed: int
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
|
||||
if resolution == "1080p" and aspect_ratio in {"9:16", "3:4"}:
|
||||
raise ValueError("1080p is not available for vertical aspect ratios (9:16, 3:4) when reframing.")
|
||||
source_url = await upload_video_to_comfyapi(cls, video, max_duration=30)
|
||||
request = Luma2GenerationRequest(
|
||||
prompt=prompt,
|
||||
model="ray-3.2",
|
||||
type="video_reframe",
|
||||
aspect_ratio=aspect_ratio,
|
||||
source=Luma2ImageRef(url=source_url, media_type="video/mp4"),
|
||||
video=Luma2VideoOptions(resolution=resolution),
|
||||
)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaRay32ExtendVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="LumaRay32ExtendVideoNode",
|
||||
display_name="Luma Ray 3.2 Extend Video",
|
||||
category="partner/video/Luma",
|
||||
description="Extend a previous Ray 3.2 generation forward (continue after it) or backward (lead-in "
|
||||
"before it). Connect the generation_id output of a prior Luma Ray 3.2 node."
|
||||
" Extensions are always 5 seconds.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"source_generation_id",
|
||||
default="",
|
||||
tooltip="generation_id of the prior Ray 3.2 video to extend."
|
||||
" Connect the generation_id output of another Luma Ray 3.2 node.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"direction",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Forward (continue after)",
|
||||
[
|
||||
IO.Boolean.Input(
|
||||
"loop",
|
||||
default=False,
|
||||
tooltip="Loop the extended video seamlessly (forward extend only).",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option("Backward (lead-in before)", []),
|
||||
],
|
||||
tooltip="Forward continues after the prior clip; backward is prepended before it.",
|
||||
),
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Text prompt for the new content."),
|
||||
IO.Combo.Input("resolution", options=["540p", "720p", "1080p"], default="720p"),
|
||||
_ray32_seed_input(),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
IO.String.Output(display_name="generation_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=_BADGE_RAY32_VIDEO_5S,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls, source_generation_id: str, direction: dict, prompt: str, resolution: str, seed: int
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=6000)
|
||||
gen_id = (source_generation_id or "").strip()
|
||||
if not gen_id:
|
||||
raise ValueError(
|
||||
"source_generation_id is required (connect the generation_id output of a prior Luma Ray 3.2 node)."
|
||||
)
|
||||
video = Luma2VideoOptions(resolution=resolution, duration="5s")
|
||||
ref = Luma2ImageRef(generation_id=gen_id)
|
||||
if direction["direction"] == "Forward (continue after)":
|
||||
video.start_frame = ref
|
||||
if direction.get("loop"):
|
||||
video.loop = True
|
||||
else:
|
||||
video.end_frame = ref
|
||||
request = Luma2GenerationRequest(prompt=prompt, model="ray-3.2", type="video", video=video)
|
||||
return await _ray32_generate(cls, request)
|
||||
|
||||
|
||||
class LumaExtension(ComfyExtension):
|
||||
@ -944,6 +1481,13 @@ class LumaExtension(ComfyExtension):
|
||||
LumaConceptsNode,
|
||||
LumaImageNode,
|
||||
LumaImageEditNode,
|
||||
LumaRay32TextToVideoNode,
|
||||
LumaRay32ImageToVideoNode,
|
||||
LumaRay32KeyframeNode,
|
||||
LumaRay32KeyframesToVideoNode,
|
||||
LumaRay32VideoEditNode,
|
||||
LumaRay32VideoReframeNode,
|
||||
LumaRay32ExtendVideoNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.openai import (
|
||||
InputFileContent,
|
||||
@ -62,7 +63,8 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
A torch.Tensor of shape (N, H, W, C) with all returned images; images whose
|
||||
dimensions differ from the first image's are resized to match it.
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
@ -89,6 +91,14 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
# With size="auto" the API can return images whose dimensions differ by a few pixels within a single response
|
||||
# resize them to the first image's dimensions so they can be stacked into one batch.
|
||||
ref_h, ref_w = image_tensors[0].shape[:2]
|
||||
for i, t in enumerate(image_tensors):
|
||||
if t.shape[:2] != (ref_h, ref_w):
|
||||
samples = t.unsqueeze(0).movedim(-1, 1)
|
||||
samples = common_upscale(samples, ref_w, ref_h, "bilinear", "center")
|
||||
image_tensors[i] = samples.movedim(1, -1).squeeze(0)
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
|
||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
RunwayAleph2IO,
|
||||
RunwayAleph2KeyframeChain,
|
||||
RunwayAleph2KeyframeItem,
|
||||
RunwayAleph2PromptImageChain,
|
||||
RunwayAleph2PromptImageItem,
|
||||
RunwayAleph2Request,
|
||||
RunwayAleph2Response,
|
||||
RunwayAleph2KeyframeSeconds,
|
||||
RunwayAleph2KeyframeAt,
|
||||
RunwayAleph2PromptImage,
|
||||
RunwayAleph2TimestampPosition,
|
||||
RunwayAleph2RelativePosition,
|
||||
RunwayAleph2ContentModeration,
|
||||
KEYFRAME_MODE_SECONDS,
|
||||
KEYFRAME_MODE_AT,
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||
PROMPT_IMAGE_MODE_POSITION,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@ -127,7 +133,7 @@ async def generate_video(
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
mime_type="image/png",
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2KeyframeNode",
|
||||
display_name="Runway Aleph2 Keyframe",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
|
||||
"steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
|
||||
"the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'keyframes' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to apply at the chosen moment of the input video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"timing",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the input video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the input video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the input video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Optional earlier keyframes to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
timing: dict,
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain()
|
||||
if timing["timing"] == _TIMING_ABSOLUTE:
|
||||
mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"])
|
||||
else:
|
||||
mode, value = KEYFRAME_MODE_AT, float(timing["fraction"])
|
||||
chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2PromptImageNode",
|
||||
display_name="Runway Aleph2 Prompt Image",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the output (result) video, to guide what "
|
||||
"the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
|
||||
"Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'prompt_images' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to place at the chosen moment of the output video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"position",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the output video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the output video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the output video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Optional earlier prompt images to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
position: dict,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain()
|
||||
if position["position"] == _TIMING_ABSOLUTE:
|
||||
mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"])
|
||||
else:
|
||||
mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"])
|
||||
chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2VideoToVideoNode",
|
||||
display_name="Runway Aleph2 Video to Video",
|
||||
category="partner/video/Runway",
|
||||
description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
|
||||
"your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
|
||||
"the original motion and timing; the output resolution matches the input video, which must be "
|
||||
"2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
|
||||
"the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes what should appear in the output (1-1000 characters).",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"public_figure_threshold",
|
||||
options=["auto", "low"],
|
||||
default="low",
|
||||
tooltip="Content moderation for recognizable public figures.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
public_figure_threshold: str = "low",
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=1000)
|
||||
validate_video_duration(
|
||||
video,
|
||||
min_duration=2.0,
|
||||
max_duration=30.0,
|
||||
)
|
||||
try:
|
||||
fps = float(video.get_frame_rate())
|
||||
except Exception:
|
||||
fps = None
|
||||
if fps is not None and fps > 30.0 + 0.01:
|
||||
raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")
|
||||
|
||||
if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
|
||||
raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
|
||||
|
||||
video_duration: float | None = None
|
||||
try:
|
||||
video_duration = video.get_duration()
|
||||
except Exception:
|
||||
video_duration = None
|
||||
|
||||
def _check_seconds(value: float, label: str) -> None:
|
||||
if video_duration is not None and value > video_duration + 0.0001:
|
||||
raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")
|
||||
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
|
||||
keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
|
||||
if keyframes is not None:
|
||||
if len(keyframes.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 keyframes.")
|
||||
for item in keyframes.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
if item.mode == KEYFRAME_MODE_SECONDS:
|
||||
_check_seconds(item.value, "Keyframe timestamp")
|
||||
keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
|
||||
else:
|
||||
keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))
|
||||
|
||||
prompt_image_models: list[RunwayAleph2PromptImage] = []
|
||||
if prompt_images is not None:
|
||||
if len(prompt_images.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 prompt images.")
|
||||
for item in prompt_images.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
|
||||
_check_seconds(item.value, "Prompt image timestamp")
|
||||
position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
|
||||
else:
|
||||
position = RunwayAleph2RelativePosition(positionPercentage=item.value)
|
||||
prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayAleph2Response,
|
||||
data=RunwayAleph2Request(
|
||||
promptText=prompt,
|
||||
videoUri=video_url,
|
||||
seed=seed,
|
||||
contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
|
||||
keyframes=keyframe_models or None,
|
||||
promptImage=prompt_image_models or None,
|
||||
),
|
||||
)
|
||||
|
||||
final_response = await get_response(cls, initial_response.id)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
RunwayAleph2VideoToVideoNode,
|
||||
RunwayAleph2KeyframeNode,
|
||||
RunwayAleph2PromptImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
from comfy_api_nodes.util._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
)
|
||||
@ -100,8 +100,7 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
node_id="SoniloTextToMusic",
|
||||
display_name="Sonilo Text to Music",
|
||||
category="partner/audio/Sonilo",
|
||||
description="Generate music from a text prompt using Sonilo's AI model. "
|
||||
"Leave duration at 0 to let the model infer it from the prompt.",
|
||||
description="Generate music from a text prompt using Sonilo's AI model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -111,11 +110,10 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=0,
|
||||
min=0,
|
||||
default=30,
|
||||
min=1,
|
||||
max=360,
|
||||
tooltip="Target duration in seconds. Set to 0 to let the model "
|
||||
"infer the duration from the prompt. Maximum: 6 minutes.",
|
||||
tooltip="Target duration in seconds. Maximum: 6 minutes.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
@ -136,13 +134,7 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"]),
|
||||
expr="""
|
||||
(
|
||||
widgets.duration > 0
|
||||
? {"type":"usd","usd": 0.005 * widgets.duration}
|
||||
: {"type":"usd","usd": 0.005, "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
expr='{"type":"usd","usd": 0.0025 * widgets.duration}',
|
||||
),
|
||||
)
|
||||
|
||||
@ -150,14 +142,13 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
duration: int = 0,
|
||||
duration: int = 1,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1, max_length=1000)
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("prompt", prompt)
|
||||
if duration > 0:
|
||||
form.add_field("duration", str(duration))
|
||||
form.add_field("duration", str(duration))
|
||||
audio_bytes = await _stream_sonilo_music(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/sonilo/t2m/generate", method="POST"),
|
||||
@ -174,8 +165,7 @@ async def _stream_sonilo_music(
|
||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(get_auth_header(cls))
|
||||
headers = get_comfy_api_headers(cls)
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
node_id = get_node_id(cls)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoFileEmptyReference,
|
||||
TripoFileReference,
|
||||
TripoImageToModelRequest,
|
||||
TripoImportModelRequest,
|
||||
TripoModelVersion,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoOrientation,
|
||||
@ -21,6 +22,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoTaskType,
|
||||
TripoTextToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoTexturePrompt,
|
||||
TripoUrlReference,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -28,6 +30,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
|
||||
@ -538,6 +541,14 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Optional text guidance for texturing. Required in practice for imported "
|
||||
"models (Tripo: Import Model), which carry no source image to infer colors from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -571,6 +582,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
texture_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -583,6 +595,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
@ -915,6 +928,90 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
|
||||
class TripoImportModelNode(IO.ComfyNode):
|
||||
"""Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""
|
||||
|
||||
SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoImportModelNode",
|
||||
display_name="Tripo: Import Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
|
||||
"to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
|
||||
"GLB is recommended: textures survive import only when embedded in the file. "
|
||||
"Note that texturing an imported model requires a texture prompt.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
|
||||
tooltip="3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
|
||||
"OBJ and STL files carry no embedded textures.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"text","text":"Free"}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
|
||||
file_format = (model_3d.format or "").lstrip(".").lower()
|
||||
if file_format == "gltf":
|
||||
raise ValueError(
|
||||
"GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
|
||||
)
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported 3D format '{file_format or 'unknown'}'. "
|
||||
f"Tripo import supports: {', '.join(f.upper() for f in cls.SUPPORTED_FORMATS)}."
|
||||
)
|
||||
size = len(model_3d.get_bytes())
|
||||
if size > 150 * 1024 * 1024:
|
||||
raise ValueError(f"Model file is {size / (1024 * 1024):.1f} MB; Tripo import allows up to 150 MB.")
|
||||
|
||||
url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=TripoImportModelRequest(url=url, format=file_format),
|
||||
)
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to import model: {response.error}")
|
||||
|
||||
task_id = response.data.task_id
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
|
||||
response_model=TripoTaskResponse,
|
||||
failed_statuses=[
|
||||
TripoTaskStatus.FAILED,
|
||||
TripoTaskStatus.CANCELLED,
|
||||
TripoTaskStatus.UNKNOWN,
|
||||
TripoTaskStatus.BANNED,
|
||||
TripoTaskStatus.EXPIRED,
|
||||
],
|
||||
status_extractor=lambda x: x.data.status,
|
||||
progress_extractor=lambda x: x.data.progress,
|
||||
estimated_duration=10,
|
||||
)
|
||||
if response_poll.data.status != TripoTaskStatus.SUCCESS:
|
||||
raise RuntimeError(f"Failed to import model: {response_poll}")
|
||||
return IO.NodeOutput(task_id)
|
||||
|
||||
|
||||
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
|
||||
return (
|
||||
"("
|
||||
@ -1292,6 +1389,7 @@ class TripoExtension(ComfyExtension):
|
||||
TripoP1TextToModelNode,
|
||||
TripoP1ImageToModelNode,
|
||||
TripoP1MultiviewToModelNode,
|
||||
TripoImportModelNode,
|
||||
TripoTextureNode,
|
||||
TripoRefineNode,
|
||||
TripoRigNode,
|
||||
|
||||
@ -48,10 +48,13 @@ from comfy_api_nodes.util import (
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_audio_duration,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_duration,
|
||||
)
|
||||
|
||||
|
||||
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||
|
||||
|
||||
@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-t2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. "
|
||||
"Supports English and Chinese.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=[
|
||||
"16:9",
|
||||
"9:16",
|
||||
"1:1",
|
||||
"4:3",
|
||||
"3:4",
|
||||
"21:9",
|
||||
"9:21",
|
||||
"5:4",
|
||||
"4:5",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-t2v",
|
||||
[
|
||||
@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-i2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the elements and visual features. "
|
||||
"Supports English and Chinese.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-i2v",
|
||||
[
|
||||
@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
seed: int,
|
||||
watermark: bool,
|
||||
):
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False)
|
||||
media = [
|
||||
Wan27MediaItem(
|
||||
type="first_frame",
|
||||
@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.1-r2v",
|
||||
[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt describing the video. Use identifiers such as 'character1' and "
|
||||
"'character2' to refer to the reference characters.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["720P", "1080P"],
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"ratio",
|
||||
options=[
|
||||
"16:9",
|
||||
"9:16",
|
||||
"1:1",
|
||||
"4:3",
|
||||
"3:4",
|
||||
"21:9",
|
||||
"9:21",
|
||||
"5:4",
|
||||
"4:5",
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=5,
|
||||
min=3,
|
||||
max=15,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("reference_image"),
|
||||
names=[
|
||||
"image1",
|
||||
"image2",
|
||||
"image3",
|
||||
"image4",
|
||||
"image5",
|
||||
"image6",
|
||||
"image7",
|
||||
"image8",
|
||||
"image9",
|
||||
],
|
||||
min=1,
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"happyhorse-1.0-r2v",
|
||||
[
|
||||
@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
(
|
||||
$res := $lookup(widgets, "model.resolution");
|
||||
$dur := $lookup(widgets, "model.duration");
|
||||
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
|
||||
$ppsTable := $contains(widgets.model, "1.1")
|
||||
? { "720p": 0.2002, "1080p": 0.2574 }
|
||||
: { "720p": 0.14, "1080p": 0.24 };
|
||||
$pps := $lookup($ppsTable, $res);
|
||||
{ "type": "usd", "usd": $pps * $dur }
|
||||
)
|
||||
@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
watermark: bool,
|
||||
):
|
||||
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
|
||||
media = []
|
||||
reference_images = model.get("reference_images", {})
|
||||
for key in reference_images:
|
||||
validate_image_dimensions(reference_images[key], min_width=400, min_height=400)
|
||||
validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False)
|
||||
media = []
|
||||
for key in reference_images:
|
||||
media.append(
|
||||
Wan27MediaItem(
|
||||
@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
)
|
||||
)
|
||||
if not media:
|
||||
raise ValueError("At least one reference reference image must be provided.")
|
||||
raise ValueError("At least one reference image must be provided.")
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
|
||||
@ -4,11 +4,14 @@ import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from email.utils import parsedate_to_datetime
|
||||
from io import BytesIO
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
@ -35,6 +38,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
|
||||
"""Source of the prompt that triggered this API node.
|
||||
|
||||
Defaults to "comfyui-api" when the submitting client didn't identify itself,
|
||||
i.e. a direct API call to this server.
|
||||
"""
|
||||
return node_cls.hidden.comfy_usage_source or "comfyui-api"
|
||||
|
||||
|
||||
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.
|
||||
|
||||
Centralizes the shared header set so every Comfy API request sends a consistent
|
||||
set and new shared headers only need to be added in one place. Intended for
|
||||
relative/cloud URLs resolved against ``default_base_url()``; because the result
|
||||
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
|
||||
"""
|
||||
return {
|
||||
**get_auth_header(node_cls),
|
||||
"Comfy-Env": get_deploy_environment(),
|
||||
"Comfy-Usage-Source": get_usage_source(node_cls),
|
||||
}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
@ -66,6 +93,32 @@ async def sleep_with_interrupt(
|
||||
await asyncio.sleep(min(1.0, end - now))
|
||||
|
||||
|
||||
def _retry_after_wait(value: str | None, fallback: float, max_wait: float) -> float:
|
||||
"""Delay before the next retry, honoring a server ``Retry-After`` header."""
|
||||
|
||||
seconds: float | None = None
|
||||
if value is not None:
|
||||
value = value.strip()
|
||||
if value.isascii() and value.isdigit():
|
||||
# delay-seconds form. The ASCII-digit guard keeps exotic Unicode "digit" characters away from float()
|
||||
# an all-digit string always converts (huge values become inf, never raising).
|
||||
seconds = float(value)
|
||||
elif value:
|
||||
# HTTP-date form. parsedate_to_datetime raises OverflowError (not a ValueError) on absurd years/offsets
|
||||
try:
|
||||
parsed = parsedate_to_datetime(value)
|
||||
except (TypeError, ValueError, OverflowError):
|
||||
parsed = None
|
||||
if parsed is not None:
|
||||
if parsed.tzinfo is None: # naive datetime: HTTP-date is UTC
|
||||
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||
delta = (parsed - datetime.now(timezone.utc)).total_seconds()
|
||||
seconds = delta if delta > 0 else 0.0
|
||||
if seconds is None:
|
||||
return fallback
|
||||
return min(seconds, max_wait)
|
||||
|
||||
|
||||
def mimetype_to_extension(mime_type: str) -> str:
|
||||
"""Converts a MIME type to a file extension."""
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
@ -19,12 +19,11 @@ from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
_retry_after_wait,
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
@ -84,6 +83,7 @@ class _PollUIState:
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
_MAX_RETRY_AFTER_WAIT = 150.0 # Cap a server Retry-After at this many seconds so a large hint can't block execution
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
|
||||
@ -645,8 +645,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
@ -750,6 +749,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
should_retry = True
|
||||
|
||||
if should_retry:
|
||||
wait_time = _retry_after_wait(resp.headers.get("Retry-After"), wait_time, _MAX_RETRY_AFTER_WAIT)
|
||||
logging.warning(
|
||||
"HTTP %s %s -> %s. Waiting %.2fs (%s).",
|
||||
method,
|
||||
|
||||
@ -17,7 +17,7 @@ from folder_paths import get_output_directory
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
headers = get_comfy_api_headers(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
|
||||
66
comfy_execution/asset_enrichment.py
Normal file
66
comfy_execution/asset_enrichment.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Enrich executed-node output entries with asset id."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
def enrich_output_with_assets(output_ui: dict) -> dict:
|
||||
"""Register file-type output entries as assets and inject their ``id``.
|
||||
|
||||
Runs at output-processing time, once per produced output, when
|
||||
--enable-assets is set. Returns a new dict; entries without a resolvable
|
||||
on-disk file path are left unchanged. Errors are caught per-entry so a
|
||||
failure never blocks execution or the other entries.
|
||||
"""
|
||||
from comfy.cli_args import args
|
||||
if not args.enable_assets:
|
||||
return output_ui
|
||||
|
||||
import folder_paths
|
||||
from app.assets.services.ingest import register_file_in_place, DependencyMissingError
|
||||
|
||||
enriched = {}
|
||||
for key, entries in output_ui.items():
|
||||
if not isinstance(entries, list):
|
||||
enriched[key] = entries
|
||||
continue
|
||||
new_entries = []
|
||||
for entry in entries:
|
||||
if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
try:
|
||||
base = folder_paths.get_directory_by_type(entry["type"])
|
||||
if base is None:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
base_abs = os.path.abspath(base)
|
||||
abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"]))
|
||||
try:
|
||||
if os.path.commonpath([base_abs, abs_path]) != base_abs:
|
||||
raise ValueError("escapes base")
|
||||
except ValueError:
|
||||
logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
if not os.path.isfile(abs_path):
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
|
||||
# Register unconditionally: the file was just produced, and
|
||||
# register_file_in_place re-hashes so an overwritten path can
|
||||
# never carry a stale id.
|
||||
result = register_file_in_place(
|
||||
abs_path=abs_path,
|
||||
name=entry["filename"],
|
||||
tags=[entry["type"]],
|
||||
)
|
||||
|
||||
entry = dict(entry)
|
||||
entry["id"] = result.ref.id
|
||||
except DependencyMissingError:
|
||||
logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
|
||||
except Exception:
|
||||
logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
|
||||
new_entries.append(entry)
|
||||
enriched[key] = new_entries
|
||||
return enriched
|
||||
@ -3,11 +3,23 @@ Job utilities for the /api/jobs endpoint.
|
||||
Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import Callable, Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
|
||||
|
||||
# Result of classifying a job for cancellation.
|
||||
# 'running' -> job is currently executing (interrupt it)
|
||||
# 'pending' -> job is queued but not started (dequeue it)
|
||||
# 'terminal' -> job already finished (present in history); cancel is a no-op
|
||||
# 'unknown' -> job id is not present anywhere
|
||||
CANCEL_RUNNING = 'running'
|
||||
CANCEL_PENDING = 'pending'
|
||||
CANCEL_TERMINAL = 'terminal'
|
||||
CANCEL_UNKNOWN = 'unknown'
|
||||
|
||||
|
||||
class JobStatus:
|
||||
"""Job status constants."""
|
||||
PENDING = 'pending'
|
||||
@ -19,6 +31,25 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
|
||||
"""Validate a client-supplied job (prompt) id.
|
||||
|
||||
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
||||
is stored and compared verbatim everywhere downstream — history keys,
|
||||
websocket events, and /interrupt matching — so accepting another spelling
|
||||
would silently rewrite the client's id and then miss every exact-match
|
||||
lookup. Rejecting loudly beats that.
|
||||
|
||||
Returns the id unchanged. Raises ValueError when the value is not a
|
||||
string in canonical UUID form.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"job id must be a string, got {type(value).__name__}")
|
||||
if str(uuid.UUID(value)) != value:
|
||||
raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
|
||||
return value
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
@ -387,3 +418,71 @@ def get_all_jobs(
|
||||
jobs = jobs[:limit]
|
||||
|
||||
return (jobs, total_count)
|
||||
|
||||
|
||||
def classify_job_for_cancel(prompt_id: str, running: list, queued: list, history: dict) -> str:
|
||||
"""Classify a job id for cancellation.
|
||||
|
||||
Returns one of CANCEL_RUNNING, CANCEL_PENDING, CANCEL_TERMINAL, CANCEL_UNKNOWN.
|
||||
|
||||
Queue items are tuples whose second element (index 1) is the prompt_id.
|
||||
History is a dict keyed by prompt_id, so a job present there has already
|
||||
finished and cancelling it is a no-op.
|
||||
"""
|
||||
for item in running:
|
||||
if item[1] == prompt_id:
|
||||
return CANCEL_RUNNING
|
||||
for item in queued:
|
||||
if item[1] == prompt_id:
|
||||
return CANCEL_PENDING
|
||||
if prompt_id in history:
|
||||
return CANCEL_TERMINAL
|
||||
return CANCEL_UNKNOWN
|
||||
|
||||
|
||||
def cancel_job(
|
||||
prompt_id: str,
|
||||
running: list,
|
||||
queued: list,
|
||||
history: dict,
|
||||
interrupt: Callable[[str], bool],
|
||||
dequeue: Callable[[str], bool],
|
||||
) -> str:
|
||||
"""Cancel a single job by id, regardless of state.
|
||||
|
||||
Maps the cancel onto the runtime's existing mechanics:
|
||||
- a running job is interrupted via ``interrupt``
|
||||
- a pending job is removed from the queue via ``dequeue``
|
||||
- a job that already finished (terminal) is a no-op
|
||||
- an unknown id is a no-op (callers that need fail-fast behaviour should
|
||||
validate ids up front with ``classify_job_for_cancel``)
|
||||
|
||||
Both ``interrupt`` and ``dequeue`` take the prompt id and return whether
|
||||
they acted on a job that was *actually* in that state, so the value returned
|
||||
here reflects what truly happened rather than the (possibly stale)
|
||||
classification. This matters around the narrow TOCTOU windows where a job
|
||||
changes state between the caller's snapshot and the action:
|
||||
|
||||
- a job classified RUNNING may have finished before ``interrupt`` fires:
|
||||
``interrupt`` returns False and this returns CANCEL_UNKNOWN (no-op).
|
||||
- a job classified PENDING may have started executing before ``dequeue``
|
||||
fires: ``dequeue`` returns False, ``interrupt`` then catches the now-
|
||||
running job and this returns CANCEL_RUNNING. If it had simply finished
|
||||
instead, both return False and this returns CANCEL_UNKNOWN.
|
||||
|
||||
``interrupt`` must be atomic — interrupt the job only if it is still the one
|
||||
running — so a cancel can never land on an unrelated prompt that started in
|
||||
the meantime (see ``execution.PromptQueue.interrupt_if_running``).
|
||||
"""
|
||||
classification = classify_job_for_cancel(prompt_id, running, queued, history)
|
||||
if classification == CANCEL_RUNNING:
|
||||
return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
|
||||
if classification == CANCEL_PENDING:
|
||||
if dequeue(prompt_id):
|
||||
return CANCEL_PENDING
|
||||
# Left the pending queue between classification and dequeue: if it
|
||||
# started executing, interrupt the now-running job; otherwise it has
|
||||
# already finished and the cancel is a genuine no-op.
|
||||
return CANCEL_RUNNING if interrupt(prompt_id) else CANCEL_UNKNOWN
|
||||
# CANCEL_TERMINAL and CANCEL_UNKNOWN are intentional no-ops.
|
||||
return classification
|
||||
|
||||
@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="model/conditioning",
|
||||
category="model/conditioning/ace",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="model/conditioning",
|
||||
category="model/conditioning/ace",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="model/latent/audio",
|
||||
category="model/latent/ace",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="model/latent/audio",
|
||||
category="model/latent/ace",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
IO.Int.Input(
|
||||
@ -111,8 +111,8 @@ class ReferenceAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReferenceTimbreAudio",
|
||||
display_name="Reference Audio",
|
||||
category="advanced/conditioning/audio",
|
||||
display_name="Set Reference Audio",
|
||||
category="model/conditioning",
|
||||
is_experimental=True,
|
||||
description="This node sets the reference audio for ace step 1.5",
|
||||
inputs=[
|
||||
|
||||
@ -16,7 +16,7 @@ class APG(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="APG",
|
||||
display_name="Adaptive Projected Guidance",
|
||||
category="model/sampling/custom_sampling",
|
||||
category="model/sampling/custom",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
|
||||
@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyARVideoLatent",
|
||||
category="model/latent/video",
|
||||
category="model/latent/autoregressive",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ARVideoI2V",
|
||||
category="model/conditioning/video_models",
|
||||
category="model/conditioning/autoregressive",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Vae.Input("vae"),
|
||||
|
||||
@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="model/latent/audio",
|
||||
category="model/latent",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ConditioningStableAudio",
|
||||
category="model/conditioning",
|
||||
category="model/conditioning/stable audio",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
node_id="VAEEncodeAudio",
|
||||
search_aliases=["audio to latent"],
|
||||
display_name="VAE Encode Audio",
|
||||
category="model/latent/audio",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Vae.Input("vae"),
|
||||
@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
node_id="VAEDecodeAudio",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio",
|
||||
category="model/latent/audio",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="model/latent/audio",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
display_name="Save Audio (FLAC) (DEPRECATED)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
@ -166,7 +166,9 @@ class SaveAudio(IO.ComfyNode):
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -174,11 +176,10 @@ class SaveAudio(IO.ComfyNode):
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
audio,
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||
)
|
||||
|
||||
save_flac = execute # TODO: remove
|
||||
|
||||
|
||||
class SaveAudioMP3(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioMP3",
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
display_name="Save Audio (MP3) (DEPRECATED)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
@ -195,7 +196,9 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -203,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
audio,
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
)
|
||||
)
|
||||
|
||||
save_mp3 = execute # TODO: remove
|
||||
|
||||
|
||||
class SaveAudioOpus(IO.ComfyNode):
|
||||
@classmethod
|
||||
@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioOpus",
|
||||
search_aliases=["export opus"],
|
||||
display_name="Save Audio (Opus)",
|
||||
display_name="Save Audio (Opus) (DEPRECATED)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
@ -225,7 +227,9 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -233,12 +237,57 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
if audio is None:
|
||||
raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(
|
||||
audio,
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
)
|
||||
)
|
||||
|
||||
save_opus = execute # TODO: remove
|
||||
|
||||
class SaveAudioAdvanced(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioAdvanced",
|
||||
search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"],
|
||||
display_name="Save Audio (Advanced)",
|
||||
description="Saves the input audio to your ComfyUI output directory.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio", tooltip="The audio to save."),
|
||||
IO.String.Input(
|
||||
"filename_prefix",
|
||||
default="audio/ComfyUI",
|
||||
tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"format",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("flac", []),
|
||||
IO.DynamicCombo.Option("mp3", [
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
]),
|
||||
IO.DynamicCombo.Option("opus", [
|
||||
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||
]),
|
||||
],
|
||||
tooltip="The file format in which to save the audio.",
|
||||
),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput:
|
||||
file_format = format.get("format", None)
|
||||
quality = format.get("quality", None)
|
||||
if quality:
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality)
|
||||
else:
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format)
|
||||
return IO.NodeOutput(audio, ui=ui)
|
||||
|
||||
|
||||
class PreviewAudio(IO.ComfyNode):
|
||||
@ -254,13 +303,14 @@ class PreviewAudio(IO.ComfyNode):
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
if audio is None:
|
||||
raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).")
|
||||
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
||||
return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls))
|
||||
|
||||
save_flac = execute # TODO: remove
|
||||
|
||||
@ -822,6 +872,7 @@ class AudioExtension(ComfyExtension):
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
SaveAudioAdvanced,
|
||||
LoadAudio,
|
||||
PreviewAudio,
|
||||
ConditioningStableAudio,
|
||||
|
||||
108
comfy_extras/nodes_bernini.py
Normal file
108
comfy_extras/nodes_bernini.py
Normal file
@ -0,0 +1,108 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def _resize_long_edge(image, max_size, stride=16):
|
||||
"""Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`"""
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = min(max_size / max(h, w), 1.0)
|
||||
nh = max(stride, round(h * scale / stride) * stride)
|
||||
nw = max(stride, round(w * scale / stride) * stride)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
class BerniniConditioning(io.ComfyNode):
|
||||
"""Bernini in-context conditioning for a Wan2.2-A14B model.
|
||||
|
||||
Attaches the VAE-encoded source video / reference images to the conditioning
|
||||
source video first, then each reference image
|
||||
|
||||
The task is inferred from which inputs are connected:
|
||||
(nothing) -> t2v (text-to-video)
|
||||
source_video -> v2v (video-to-video)
|
||||
source_video + ref_images -> rv2v (reference-guided video editing)
|
||||
ref_images only -> r2v (reference-to-video)
|
||||
source_video + ref_video -> ads2v (insert image/video into video)
|
||||
|
||||
source_video is the edit base / canvas (resized to width x height).
|
||||
reference_video is moving content to composite in.
|
||||
Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BerniniConditioning",
|
||||
display_name="Bernini Conditioning",
|
||||
category="model/conditioning/bernini",
|
||||
description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
|
||||
"Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=8192, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("source_video", optional=True, tooltip=("Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
|
||||
io.Image.Input("reference_video", optional=True, tooltip=("Video to insert into the source video (ads2v).")),
|
||||
io.Autogrow.Input("reference_images", optional=True,
|
||||
template=io.Autogrow.TemplatePrefix(
|
||||
input=io.Image.Input("reference_image", tooltip=("Reference image injected as an in-context token (r2v, rv2v).")),
|
||||
prefix="reference_image_", min=0, max=8)),
|
||||
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
|
||||
"Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
# source_video (1), reference_video (2), reference_images (3, 4, ...).
|
||||
context = []
|
||||
if source_video is not None:
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
context.append(vae.encode(vid[:, :, :, :3]))
|
||||
|
||||
if reference_video is not None:
|
||||
ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect
|
||||
context.append(vae.encode(ref_vid[:, :, :, :3]))
|
||||
|
||||
# reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
|
||||
# separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
|
||||
if reference_images:
|
||||
for name in sorted(reference_images):
|
||||
imgs = reference_images[name]
|
||||
if imgs is None:
|
||||
continue
|
||||
for i in range(imgs.shape[0]):
|
||||
img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref
|
||||
context.append(vae.encode(img[:, :, :, :3]))
|
||||
|
||||
if context:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent})
|
||||
|
||||
|
||||
class BerniniExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [BerniniConditioning,]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BerniniExtension:
|
||||
return BerniniExtension()
|
||||
@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode):
|
||||
category="image/background removal",
|
||||
description="Generates a foreground mask to remove the background from an image using a background removal model.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"),
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, image, bg_removal_model):
|
||||
def execute(cls, bg_removal_model, image):
|
||||
mask = bg_removal_model.encode_image(image)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user