diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 2cbb00d99..2c72c8a13 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,5 +1,4 @@ -As of the time of writing this you need this driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html +As of the time of writing this you need a recent driver. Updating to the latest driver is recommended. HOW TO RUN: @@ -7,9 +6,9 @@ If you have a AMD gpu: run_amd_gpu.bat -If you have memory issues you can try disabling the smart memory management by running comfyui with: +If you have memory issues you can try enabling the new dynamic memory management by running comfyui with: -run_amd_gpu_disable_smart_memory.bat +run_amd_gpu_enable_dynamic_vram.bat IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml index eeb594d6c..a69a24a87 100644 --- a/.github/workflows/check-line-endings.yml +++ b/.github/workflows/check-line-endings.yml @@ -17,7 +17,7 @@ jobs: - name: Check for Windows line endings (CRLF) run: | # Get the list of changed files in the PR - CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }}) + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci') # Flag to track if CRLF is found CRLF_FOUND=false diff --git a/README.md b/README.md index dc2389266..786a14166 100644 --- a/README.md +++ b/README.md @@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. 
Here's a step-by-step | Flag | Description | |------|-------------| | `--enable-manager` | Enable ComfyUI-Manager | -| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) | | `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | @@ -462,16 +462,6 @@ To use the most up-to-date frontend version: This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes. -### Accessing the Legacy Frontend - -If you need to use the legacy frontend for any reason, you can access it using the following command line argument: - -``` ---front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest -``` - -This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend). - # QA ### Which GPU should I buy for this? diff --git a/alembic_db/versions/0004_drop_tag_type.py b/alembic_db/versions/0004_drop_tag_type.py new file mode 100644 index 000000000..582bec4e8 --- /dev/null +++ b/alembic_db/versions/0004_drop_tag_type.py @@ -0,0 +1,39 @@ +""" +Drop the vestigial tags.tag_type column. + +tag_type was always "user" in practice — no code path ever set it to anything +else (no system/seeded classification was ever wired up) and nothing queried it. +The column, its index (ix_tags_tag_type), and the corresponding API field were +dead weight, so they are removed. 
+ +Revision ID: 0004_drop_tag_type +Revises: 0003_add_metadata_job_id +Create Date: 2026-06-03 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0004_drop_tag_type" +down_revision = "0003_add_metadata_job_id" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("tags") as batch_op: + batch_op.drop_index("ix_tags_tag_type") + batch_op.drop_column("tag_type") + + +def downgrade() -> None: + with op.batch_alter_table("tags") as batch_op: + batch_op.add_column( + sa.Column( + "tag_type", + sa.String(length=32), + nullable=False, + server_default="user", + ) + ) + batch_op.create_index("ix_tags_tag_type", ["tag_type"]) diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 6555974e9..7ef462f5c 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -39,6 +39,7 @@ from app.assets.services import ( update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.cursor import InvalidCursorError from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() @@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu user_metadata=result.ref.user_metadata or {}, metadata=result.ref.system_metadata, job_id=result.ref.job_id, - prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat + prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility created_at=result.ref.created_at, updated_at=result.ref.updated_at, last_access_time=result.ref.last_access_time, @@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response: order_candidate = (q.order or "desc").lower() order = order_candidate if order_candidate in {"asc", "desc"} else "desc" - result = list_assets_page( - owner_id=USER_MANAGER.get_request_user_id(request), - include_tags=q.include_tags, - exclude_tags=q.exclude_tags, - name_contains=q.name_contains, - 
metadata_filter=q.metadata_filter, - limit=q.limit, - offset=q.offset, - sort=sort, - order=order, - ) + try: + result = list_assets_page( + owner_id=USER_MANAGER.get_request_user_id(request), + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=sort, + order=order, + after=q.after, + ) + except InvalidCursorError as e: + return _build_error_response(400, "INVALID_CURSOR", str(e)) summaries = [_build_asset_response(item) for item in result.items] + # has_more semantics differ by mode: + # - cursor mode: a non-empty next_cursor means there are more results. + # - offset mode: derived from total - (offset + page size). + if q.after is not None: + has_more = result.next_cursor is not None + else: + has_more = (q.offset + len(summaries)) < result.total + payload = schemas_out.AssetsList( assets=summaries, total=result.total, - has_more=(q.offset + len(summaries)) < result.total, + has_more=has_more, + next_cursor=result.next_cursor, ) return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @@ -519,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response: @_require_assets_feature_enabled async def delete_asset_route(request: web.Request) -> web.Response: reference_id = str(uuid.UUID(request.match_info["id"])) - delete_content_param = request.query.get("delete_content") - delete_content = ( - False - if delete_content_param is None - else delete_content_param.lower() not in {"0", "false", "no"} - ) try: + # Deleting an asset is a soft delete of the reference; the underlying + # content is preserved (it may be shared with other references). 
deleted = delete_asset_reference( reference_id=reference_id, owner_id=USER_MANAGER.get_request_user_id(request), - delete_content_if_orphan=delete_content, + delete_content_if_orphan=False, ) except Exception: logging.exception( @@ -575,8 +585,8 @@ async def get_tags(request: web.Request) -> web.Response: ) tags = [ - schemas_out.TagUsage(name=name, count=count, type=tag_type) - for (name, tag_type, count) in rows + schemas_out.TagUsage(name=name, count=count) + for (name, count) in rows ] payload = schemas_out.TagsList( tags=tags, total=total, has_more=(query.offset + len(tags)) < total diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 186a6ae1e..af666746d 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 + # Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination + # is supported for sort values `created_at`, `updated_at`, `name`, `size`. + # Supplying `after` together with `sort=last_access_time` returns + # 400 INVALID_CURSOR; that sort only supports offset/limit. + after: str | None = None sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( "created_at" diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 0e748b907..4e38e19d1 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -41,12 +41,13 @@ class AssetsList(BaseModel): assets: list[Asset] total: int has_more: bool + # Opaque cursor for the next page. Omitted when there are no more results. 
+ next_cursor: str | None = None class TagUsage(BaseModel): name: str count: int - type: str class TagsList(BaseModel): diff --git a/app/assets/database/models.py b/app/assets/database/models.py index a3af8a192..9b61d309a 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -227,7 +227,6 @@ class Tag(Base): __tablename__ = "tags" name: Mapped[str] = mapped_column(String(512), primary_key=True) - tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship( back_populates="tag", @@ -240,7 +239,5 @@ class Tag(Base): overlaps="asset_reference_links,tag_links,tags,asset_reference", ) - __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) - def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 8b90ae511..792411800 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -266,9 +266,18 @@ def list_references_page( metadata_filter: dict | None = None, sort: str | None = None, order: str | None = None, + after_cursor_value: object | None = None, + after_cursor_id: str | None = None, ) -> tuple[list[AssetReference], dict[str, list[str]], int]: """List references with pagination, filtering, and sorting. + When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses + keyset pagination — ``offset`` is ignored and a WHERE clause selects rows + strictly after the given ``(sort_col, id)`` position in the active sort + direction. The cursor value must already be typed for the column + (datetime for time sorts, int for size, str for name); the caller decodes + the opaque cursor string and resolves to the typed value. + Returns (references, tag_map, total_count). 
""" base = ( @@ -297,9 +306,31 @@ def list_references_page( "size": Asset.size_bytes, } sort_col = sort_map.get(sort, AssetReference.created_at) - sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + descending = order == "desc" - base = base.order_by(sort_exp).limit(limit).offset(offset) + # Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor. + # Equivalent to: sort_col v OR (sort_col = v AND id cursor_id). + if after_cursor_value is not None and after_cursor_id is not None: + if descending: + keyset = sa.or_( + sort_col < after_cursor_value, + sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id), + ) + else: + keyset = sa.or_( + sort_col > after_cursor_value, + sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id), + ) + base = base.where(keyset) + + # Secondary ORDER BY id (matching the primary direction) gives the keyset + # comparison a deterministic tiebreaker on duplicate sort_col values. + id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc() + sort_exp = sort_col.desc() if descending else sort_col.asc() + + base = base.order_by(sort_exp, id_exp).limit(limit) + if after_cursor_id is None: + base = base.offset(offset) count_stmt = ( select(sa.func.count()) diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index f4126dba8..d41d73a10 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None: raise ValueError(f"Unknown tags: {missing}") -def ensure_tags_exist( - session: Session, names: Iterable[str], tag_type: str = "user" -) -> None: +def ensure_tags_exist(session: Session, names: Iterable[str]) -> None: wanted = normalize_tags(list(names)) if not wanted: return - rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + rows = [{"name": n} for n in 
list(dict.fromkeys(wanted))] ins = ( sqlite.insert(Tag) .values(rows) @@ -97,7 +95,7 @@ def set_reference_tags( to_remove = [t for t in current if t not in desired] if to_add: - ensure_tags_exist(session, to_add, tag_type="user") + ensure_tags_exist(session, to_add) session.add_all( [ AssetReferenceTag( @@ -142,7 +140,7 @@ def add_tags_to_reference( return AddTagsResult(added=[], already_present=[], total_tags=total) if create_if_missing: - ensure_tags_exist(session, norm, tag_type="user") + ensure_tags_exist(session, norm) current = set(get_reference_tags(session, reference_id)) @@ -289,7 +287,6 @@ def list_tags_with_usage( q = ( select( Tag.name, - Tag.tag_type, func.coalesce(counts_sq.c.cnt, 0).label("count"), ) .select_from(Tag) @@ -331,7 +328,7 @@ def list_tags_with_usage( rows = (session.execute(q.limit(limit).offset(offset))).all() total = (session.execute(total_q)).scalar_one() - rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + rows_norm = [(name, int(count or 0)) for (name, count) in rows] return rows_norm, int(total or 0) diff --git a/app/assets/scanner.py b/app/assets/scanner.py index ebb6869af..2c1e97840 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -33,6 +33,7 @@ from app.assets.services.file_utils import ( verify_file_unchanged, ) from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.metadata_extract import extract_file_metadata from app.assets.services.path_utils import ( compute_relative_filename, @@ -354,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: return 0 with create_session() as sess: if tag_pool: - ensure_tags_exist(sess, tag_pool, tag_type="user") + ensure_tags_exist(sess, tag_pool) result = batch_insert_seed_assets(sess, specs=specs, owner_id="") sess.commit() return result.inserted_refs @@ -506,6 +507,10 @@ def enrich_asset( 
if extract_metadata and metadata: system_metadata = metadata.to_user_metadata() + if mime_type and mime_type.startswith("image/"): + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if dims: + system_metadata.update(dims) set_reference_system_metadata(session, reference_id, system_metadata) if full_hash: diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 5aefd9956..d4e4fc61c 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -1,8 +1,19 @@ import contextlib import mimetypes import os +from datetime import timezone from typing import Sequence +from app.assets.services.cursor import ( + CursorPayload, + InvalidCursorError, + decode_cursor, + decode_cursor_int, + decode_cursor_time, + encode_cursor, + encode_cursor_from_time, +) + from app.assets.database.models import Asset from app.assets.database.queries import ( @@ -149,6 +160,16 @@ def delete_asset_reference( owner_id: str, delete_content_if_orphan: bool = True, ) -> bool: + """Delete an asset reference. + + With ``delete_content_if_orphan=False`` (a soft delete), the reference is + hidden and the underlying content is preserved. With ``True``, the content + is also removed once it becomes orphaned. + + Note: the public DELETE /api/assets/{id} endpoint always soft-deletes + (passes ``False``); the orphan-reclamation path is intentionally + internal-only, retained for a future GC/admin caller. + """ with create_session() as session: if not delete_content_if_orphan: # Soft delete: mark the reference as deleted but keep everything @@ -242,6 +263,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None: return extract_asset_data(asset) +# Sort fields that support cursor pagination. `last_access_time` is not +# in this list — it falls back to offset/limit. 
+_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size") + + def list_assets_page( owner_id: str = "", include_tags: Sequence[str] | None = None, @@ -252,7 +278,39 @@ def list_assets_page( offset: int = 0, sort: str = "created_at", order: str = "desc", + after: str | None = None, ) -> ListAssetsResult: + """List assets with optional cursor pagination. + + When ``after`` is supplied it overrides ``offset``. The cursor's sort field + must match ``sort`` and be in the cursor-supported allowlist; mismatches + raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR. + """ + cursor_value: object | None = None + cursor_id: str | None = None + # Mint next_cursor on every page where the sort is cursor-supported, not + # only when the request itself arrived with a cursor. Otherwise a first + # request (no `after`) returns next_cursor=None and the client can never + # enter cursor mode. + mint_cursor = sort in _CURSOR_SORT_FIELDS + + if after is not None: + if sort not in _CURSOR_SORT_FIELDS: + raise InvalidCursorError( + f"cursor pagination is not supported for sort={sort!r}" + ) + payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order) + if payload.sort_field != sort: + raise InvalidCursorError( + f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}" + ) + cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id + + # Over-fetch by one row so we can distinguish "exactly `limit` rows total + # remaining" from "more rows past this page" without a second query. Drop + # the sentinel before returning. 
+ fetch_limit = limit + 1 if mint_cursor else limit + with create_session() as session: refs, tag_map, total = list_references_page( session, @@ -261,12 +319,22 @@ def list_assets_page( exclude_tags=exclude_tags, name_contains=name_contains, metadata_filter=metadata_filter, - limit=limit, + limit=fetch_limit, offset=offset, sort=sort, order=order, + after_cursor_value=cursor_value, + after_cursor_id=cursor_id, ) + next_cursor: str | None = None + if mint_cursor and len(refs) > limit: + # There's at least one more row past this page — mint a cursor from + # the last row of the page (i.e. index `limit - 1`, since we + # over-fetched), and drop the sentinel. + next_cursor = _encode_next_cursor(refs[limit - 1], sort, order) + refs = refs[:limit] + items: list[AssetSummaryData] = [] for ref in refs: items.append( @@ -277,7 +345,39 @@ def list_assets_page( ) ) - return ListAssetsResult(items=items, total=total) + return ListAssetsResult(items=items, total=total, next_cursor=next_cursor) + + +def _resolve_cursor_value(payload: CursorPayload) -> object: + """Map a decoded cursor payload to a column-typed Python value.""" + if payload.sort_field in ("created_at", "updated_at"): + # DB stores naive UTC; strip tzinfo so the comparison binds against a + # `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift. + return decode_cursor_time(payload).replace(tzinfo=None) + if payload.sort_field == "size": + return decode_cursor_int(payload) + return payload.value # name, str-typed + + +def _encode_next_cursor(ref, sort: str, order: str) -> str | None: + """Mint a cursor pointing at *ref* for the given sort dimension. + + Returns None when the boundary row carries a NULL sort value (e.g. an asset + record whose size_bytes hasn't been backfilled). Continuing pagination + across a NULL boundary is undefined under keyset ordering — better to + truncate cleanly here than to mint a cursor that mis-positions. 
+ """ + if sort == "name": + return encode_cursor("name", ref.name, ref.id, order=order) + if sort == "size": + if ref.asset is None or ref.asset.size_bytes is None: + return None + return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order) + # created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding. + value = ref.created_at if sort == "created_at" else ref.updated_at + if value is None: + return None + return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order) def resolve_hash_to_path( diff --git a/app/assets/services/cursor.py b/app/assets/services/cursor.py new file mode 100644 index 000000000..6c7791528 --- /dev/null +++ b/app/assets/services/cursor.py @@ -0,0 +1,213 @@ +"""Opaque keyset-pagination cursor for /api/assets. + +Payload JSON uses short keys to keep the encoded length small: + + {"s": , "v": , "id": , "o": } + +The `o` key binds the cursor to the sort direction it was minted under, +so replaying a `desc` cursor against an `asc` request fails with +``INVALID_CURSOR`` rather than silently walking the wrong direction. +`o` is mandatory on every payload — a cursor without it is rejected as +malformed. + +Encoding is base64url with no padding. Cursors are opaque tokens: the +payload format is internal to this server, and clients must treat a +cursor as a black box handed back via `next_cursor`. No byte-level +compatibility with any other implementation is required. + +Time values are serialized as Unix microseconds (UTC) — microsecond +precision is sufficient to round-trip the timestamps stored by the +database without rounding rows in the same millisecond bucket. +""" +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Iterable, Optional + + +class InvalidCursorError(ValueError): + """Raised on a malformed, oversized, or unsupported-sort-field cursor. 
+ + Map to a 400 response with code ``INVALID_CURSOR`` at the handler. + """ + + +# Wire-format length caps. Cursors are user-controlled, so caps protect the +# decode path from oversized allocations and downstream SQL predicates from +# unbounded strings. +# +# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max +# (`String(512)`) — otherwise a long-named asset would mint a cursor the same +# server then refuses on the next request. +# +# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above +# the largest cursor the per-field caps can produce. Worst case is value + id +# at their caps with every character JSON-escaping to the six-byte `\uXXXX` +# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192 +# the encoder can never mint a cursor that exceeds it, so a freshly minted +# cursor always decodes on the next request and there is no user-visible +# "cursor too long" failure. +MAX_ENCODED_CURSOR_LENGTH = 8192 +MAX_CURSOR_VALUE_LENGTH = 512 +MAX_CURSOR_ID_LENGTH = 128 + + +@dataclass(frozen=True) +class CursorPayload: + sort_field: str + value: str + id: str + order: str + + +_VALID_ORDERS = ("asc", "desc") + + +def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str: + """Encode a cursor payload as a base64url (no-padding) string. + + `order` binds the cursor to the sort direction it was minted under so a + later request with a flipped `order` query parameter is rejected with + ``INVALID_CURSOR`` rather than silently walking the wrong direction. + """ + if order not in _VALID_ORDERS: + raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}") + # Symmetric input validation: the encoder must reject anything the + # decoder rejects, or the same server will mint cursors it then 400s on + # the next request. 
+ if not id: + raise InvalidCursorError("id must be non-empty") + if len(id) > MAX_CURSOR_ID_LENGTH: + raise InvalidCursorError("id exceeds maximum length") + if len(value) > MAX_CURSOR_VALUE_LENGTH: + raise InvalidCursorError("value exceeds maximum length") + payload = {"s": sort_field, "v": value, "id": id, "o": order} + raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False) + # No mint-time length guard is needed: the per-field caps above bound the + # encoded length well below MAX_ENCODED_CURSOR_LENGTH (see its definition), + # so the encoder can never produce a cursor the decode path would reject. + return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii") + + +def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str: + """Encode a time-typed cursor at Unix microsecond precision. + + Accepts an aware datetime (any timezone) and normalizes to UTC. Naive + datetimes are rejected so callers can't accidentally encode the local + wall-clock value of a UTC-stored timestamp. + """ + if t.tzinfo is None: + raise ValueError("encode_cursor_from_time requires an aware datetime") + micros = _datetime_to_unix_micros(t.astimezone(timezone.utc)) + return encode_cursor(sort_field, str(micros), id, order=order) + + +def decode_cursor( + cursor: str, + allowed_sort_fields: Iterable[str], + expected_order: str | None = None, +) -> CursorPayload: + """Parse an opaque cursor. + + ``allowed_sort_fields`` is the endpoint's accepted sort-field list — a + cursor carrying a field outside this set is rejected so a cursor minted + for one column can't be replayed against another (e.g. a ``created_at`` + timestamp string compared against a ``name`` column). + + ``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the + payload's ``o`` field. ``o`` is required on every payload; a cursor + missing it is rejected as malformed. + + Passing no allowed fields rejects every cursor. 
+ """ + if len(cursor) > MAX_ENCODED_CURSOR_LENGTH: + raise InvalidCursorError("cursor exceeds maximum length") + + try: + # urlsafe_b64decode requires correct padding; we strip on encode, so + # restore the trailing '=' pad here. + padding = "=" * (-len(cursor) % 4) + raw = base64.urlsafe_b64decode(cursor + padding) + except (ValueError, base64.binascii.Error) as e: + raise InvalidCursorError(f"encoding: {e}") from e + + try: + decoded = json.loads(raw) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise InvalidCursorError(f"payload: {e}") from e + + if not isinstance(decoded, dict): + raise InvalidCursorError("payload: expected object") + + sort_field = decoded.get("s") + value = decoded.get("v") + id = decoded.get("id") + order = decoded.get("o") + + if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str): + raise InvalidCursorError("payload: missing or non-string s/v/id") + + if id == "": + raise InvalidCursorError("missing id") + if len(id) > MAX_CURSOR_ID_LENGTH: + raise InvalidCursorError("id exceeds maximum length") + if len(value) > MAX_CURSOR_VALUE_LENGTH: + raise InvalidCursorError("value exceeds maximum length") + + if sort_field not in allowed_sort_fields: + raise InvalidCursorError(f"unsupported sort field {sort_field!r}") + + if not isinstance(order, str): + raise InvalidCursorError("missing or non-string o") + if order not in _VALID_ORDERS: + raise InvalidCursorError(f"unsupported order {order!r}") + if expected_order is not None and order != expected_order: + raise InvalidCursorError( + f"cursor order {order!r} does not match request order {expected_order!r}" + ) + + return CursorPayload(sort_field=sort_field, value=value, id=id, order=order) + + +def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime: + """Parse a time-typed cursor value as Unix microseconds, returning UTC.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + micros = int(payload.value) + 
except ValueError as e: + raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e + try: + return _unix_micros_to_datetime(micros) + except (OverflowError, OSError, ValueError) as e: + # Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up + # in fromtimestamp / datetime construction. Map to 400, not 500. + raise InvalidCursorError(f"value is out of representable range: {e}") from e + + +def decode_cursor_int(payload: Optional[CursorPayload]) -> int: + """Parse a cursor value as a base-10 integer.""" + if payload is None: + raise InvalidCursorError("nil cursor payload") + try: + return int(payload.value) + except ValueError as e: + raise InvalidCursorError(f"value is not a valid integer: {e}") from e + + +_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) + + +def _datetime_to_unix_micros(t: datetime) -> int: + """Convert an aware UTC datetime to Unix microseconds (integer math).""" + delta = t - _EPOCH + return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds + + +def _unix_micros_to_datetime(micros: int) -> datetime: + """Convert Unix microseconds to a UTC datetime, preserving precision.""" + seconds, micro_remainder = divmod(micros, 1_000_000) + return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder) diff --git a/app/assets/services/image_dimensions.py b/app/assets/services/image_dimensions.py new file mode 100644 index 000000000..ccd97399a --- /dev/null +++ b/app/assets/services/image_dimensions.py @@ -0,0 +1,63 @@ +"""Image dimension extraction for asset ingest. + +Reads only the image header via Pillow to capture width/height cheaply, +without a full pixel decode. Returns a metadata dict suitable for merging +into ``AssetReference.system_metadata``. 
+""" +from __future__ import annotations + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def extract_image_dimensions( + file_path: str, mime_type: str | None = None +) -> dict[str, Any] | None: + """Extract image dimensions for the file at ``file_path``. + + Args: + file_path: Absolute path to a file on disk. + mime_type: Optional MIME type hint. When provided and not prefixed + with ``image/``, extraction is skipped without touching the file. + + Returns: + ``{"kind": "image", "width": W, "height": H}`` when the file is a + recognizable image with positive dimensions, otherwise ``None``. + + The dict shape is intended to be merged into ``system_metadata`` so the + asset response surfaces ``metadata.kind`` plus dimension fields for image + assets. Forward-compatible: future media kinds (e.g. ``"video"`` with + duration/fps) can extend this shape without schema changes. + """ + if mime_type is not None and not mime_type.startswith("image/"): + return None + + try: + from PIL import Image, UnidentifiedImageError + except ImportError: + logger.debug( + "Pillow not available; skipping image dimension extraction for %s", + file_path, + ) + return None + + try: + with Image.open(file_path) as img: + width, height = img.size + except (OSError, UnidentifiedImageError, ValueError) as exc: + logger.debug( + "Failed to read image dimensions from %s: %s", file_path, exc + ) + return None + + if ( + not isinstance(width, int) + or not isinstance(height, int) + or width <= 0 + or height <= 0 + ): + return None + + return {"kind": "image", "width": width, "height": height} diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f0b070517..3b6dc237c 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -17,9 +17,11 @@ from app.assets.database.queries import ( get_reference_by_file_path, get_reference_tags, get_or_create_reference, + list_references_by_asset_id, reference_exists, 
remove_missing_tag_for_asset_id, set_reference_metadata, + set_reference_system_metadata, set_reference_tags, update_asset_hash_and_mime, upsert_asset, @@ -29,6 +31,7 @@ from app.assets.database.queries import ( from app.assets.helpers import get_utc_now, normalize_tags from app.assets.services.bulk_ingest import batch_insert_seed_assets from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.image_dimensions import extract_image_dimensions from app.assets.services.path_utils import ( compute_relative_filename, get_name_and_tags_from_asset_path, @@ -118,6 +121,14 @@ def _ingest_file_from_path( user_metadata=user_metadata, ) + _maybe_store_image_dimensions( + session, + reference_id=reference_id, + file_path=locator, + mime_type=mime_type, + current_system_metadata=ref.system_metadata, + ) + try: remove_missing_tag_for_asset_id(session, asset_id=asset.id) except Exception: @@ -288,6 +299,13 @@ def _register_existing_asset( user_metadata=new_meta, ) + _backfill_image_dimensions_from_siblings( + session, + asset_id=asset.id, + new_reference_id=ref.id, + current_system_metadata=ref.system_metadata, + ) + if tags is not None: set_reference_tags( session, @@ -334,6 +352,87 @@ def _update_metadata_with_filename( ) +_IMAGE_DIMENSION_KEYS = ("kind", "width", "height") + + +def _maybe_store_image_dimensions( + session: Session, + reference_id: str, + file_path: str, + mime_type: str | None, + current_system_metadata: dict | None, +) -> None: + """Populate ``kind``/``width``/``height`` on system_metadata for image refs. + + Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written + safetensors metadata, download provenance) are preserved by merge. 
+ """ + if not mime_type or not mime_type.startswith("image/"): + return + + dims = extract_image_dimensions(file_path, mime_type=mime_type) + if not dims: + return + + current = current_system_metadata or {} + merged = dict(current) + merged.update(dims) + if merged != current: + set_reference_system_metadata( + session, + reference_id=reference_id, + system_metadata=merged, + ) + + +def _backfill_image_dimensions_from_siblings( + session: Session, + asset_id: str, + new_reference_id: str, + current_system_metadata: dict | None, +) -> None: + """Copy image dimension keys from any sibling reference of the same asset. + + The from-hash path doesn't read the file bytes, so dimensions can't be + extracted there directly. When another reference of the same asset already + carries image dimensions, copy them onto the new reference so consumers + see consistent metadata regardless of how the asset was registered. + + Best-effort: missing siblings, non-image siblings, or absent dimension + keys leave the target reference unchanged. 
+ """ + current = current_system_metadata or {} + if current.get("kind") == "image" and "width" in current and "height" in current: + return + + for sibling in list_references_by_asset_id(session, asset_id): + if sibling.id == new_reference_id: + continue + meta = sibling.system_metadata or {} + if meta.get("kind") != "image": + continue + width = meta.get("width") + height = meta.get("height") + if ( + type(width) is not int + or type(height) is not int + or width <= 0 + or height <= 0 + ): + continue + merged = dict(current) + merged["kind"] = "image" + merged["width"] = width + merged["height"] = height + if merged != current: + set_reference_system_metadata( + session, + reference_id=new_reference_id, + system_metadata=merged, + ) + return + + def _sanitize_filename(name: str | None, fallback: str) -> str: n = os.path.basename((name or "").strip() or fallback) return n if n else fallback diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 0eb128f58..4d2af8a02 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -56,7 +56,6 @@ class IngestResult: class TagUsage(NamedTuple): name: str - tag_type: str count: int @@ -71,6 +70,7 @@ class AssetSummaryData: class ListAssetsResult: items: list[AssetSummaryData] total: int + next_cursor: str | None = None @dataclass(frozen=True) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index 37b612753..5fa39d26a 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -75,7 +75,7 @@ def list_tags( owner_id=owner_id, ) - return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total + return [TagUsage(name, count) for name, count in rows], total def list_tag_histogram( diff --git a/comfy/cli_args.py b/comfy/cli_args.py index a4cabcc65..e7ee0d5eb 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], 
metav cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") +cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") manager_group = parser.add_mutually_exclusive_group() manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") -manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") +manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.") vram_group = parser.add_mutually_exclusive_group() @@ -166,6 +167,8 @@ class PerformanceFeature(enum.Enum): parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. 
You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) +parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.") + parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") @@ -247,6 +250,9 @@ else: if args.cache_ram is not None and len(args.cache_ram) > 2: parser.error("--cache-ram accepts at most two values: active GB and inactive GB") +if args.high_ram: + args.cache_classic = True + if args.windows_standalone_build: args.auto_launch = True @@ -256,6 +262,10 @@ if args.disable_auto_launch: if args.force_fp16: args.fp16_unet = True +# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'. +if args.enable_manager_legacy_ui: + args.enable_manager = True + # '--fast' is not provided, use an empty set if args.fast is None: diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py index ee86f8309..53e4fdb6c 100644 --- a/comfy/image_encoders/dino2.py +++ b/comfy/image_encoders/dino2.py @@ -1,7 +1,13 @@ import torch +import torch.nn.functional as F + from comfy.text_encoders.bert import BertAttention import comfy.model_management from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.depth_anything_3.reference_view_selector import ( + select_reference_view, reorder_by_reference, restore_original_order, + THRESH_FOR_REF_SELECTION, +) class Dino2AttentionOutput(torch.nn.Module): @@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module): class Dino2AttentionBlock(torch.nn.Module): - def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations): + def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, 
operations, + qk_norm=False): super().__init__() + self.heads = heads + self.head_dim = embed_dim // heads self.attention = BertAttention(embed_dim, heads, dtype, device, operations) self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations) + if qk_norm: + self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device) + else: + self.q_norm = None + self.k_norm = None - def forward(self, x, mask, optimized_attention): - return self.output(self.attention(x, mask, optimized_attention)) + def forward(self, x, mask, optimized_attention, pos=None, rope=None): + # Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions). + if self.q_norm is None and rope is None: + return self.output(self.attention(x, mask, optimized_attention)) + + # DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE. + attn = self.attention + B, N, C = x.shape + h = self.heads + d = self.head_dim + q = attn.query(x).view(B, N, h, d).transpose(1, 2) + k = attn.key(x).view(B, N, h, d).transpose(1, 2) + v = attn.value(x).view(B, N, h, d).transpose(1, 2) + if self.q_norm is not None: + q = self.q_norm(q) + k = self.k_norm(k) + if rope is not None and pos is not None: + q = rope(q, pos) + k = rope(k, pos) + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + return self.output(out) class LayerScale(torch.nn.Module): @@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module): class Dino2Block(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn): + def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn, + qk_norm=False): super().__init__() - self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations) + self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations, + 
qk_norm=qk_norm) self.layer_scale1 = LayerScale(dim, dtype, device, operations) self.layer_scale2 = LayerScale(dim, dtype, device, operations) if use_swiglu_ffn: @@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module): self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) - def forward(self, x, optimized_attention): - x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention)) + def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None): + x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention, + pos=pos, rope=rope)) x = x + self.layer_scale2(self.mlp(self.norm2(x))) return x -class Dino2Encoder(torch.nn.Module): - def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn): +# ----------------------------------------------------------------------------- +# 2D Rotary position embedding (DA3 extension) +# ----------------------------------------------------------------------------- + + +class _PositionGetter: + """Cache (h, w) -> flat (y, x) position grid used to feed ``rope``.""" + + def __init__(self): + self._cache: dict = {} + + def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor: + key = (height, width, device) + if key not in self._cache: + y = torch.arange(height, device=device) + x = torch.arange(width, device=device) + self._cache[key] = torch.cartesian_prod(y, x) + cached = self._cache[key] + return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(torch.nn.Module): + """2D RoPE used by DA3-Small/Base. 
No learnable parameters.""" + + def __init__(self, frequency: float = 100.0): super().__init__() - self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) - for _ in range(num_layers)]) + self.base_frequency = frequency + self._freq_cache: dict = {} + + def _components(self, dim: int, seq_len: int, device, dtype): + key = (dim, seq_len, device, dtype) + if key not in self._freq_cache: + exp = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency ** exp) + pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + ang = torch.einsum("i,j->ij", pos, inv_freq) + ang = ang.to(dtype) + ang = torch.cat((ang, ang), dim=-1) + self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype)) + return self._freq_cache[key] + + @staticmethod + def _rotate(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] + x1, x2 = x[..., : d // 2], x[..., d // 2:] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d(self, tokens, positions, cos_c, sin_c): + cos = F.embedding(positions, cos_c)[:, None, :, :] + sin = F.embedding(positions, sin_c)[:, None, :, :] + return (tokens * cos) + (self._rotate(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + feature_dim = tokens.size(-1) // 2 + max_pos = int(positions.max()) + 1 + cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype) + v, h = tokens.chunk(2, dim=-1) + v = self._apply_1d(v, positions[..., 0], cos_c, sin_c) + h = self._apply_1d(h, positions[..., 1], cos_c, sin_c) + return torch.cat((v, h), dim=-1) + + +class Dino2Encoder(torch.nn.Module): + def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn, + qknorm_start: int = -1): + super().__init__() + self.layer = torch.nn.ModuleList([ + Dino2Block( + dim, num_heads, layer_norm_eps, dtype, device, operations, + 
use_swiglu_ffn=use_swiglu_ffn, + qk_norm=(qknorm_start != -1 and i >= qknorm_start), + ) + for i in range(num_layers) + ]) def forward(self, x, intermediate_output=None): + # Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions). optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) if intermediate_output is not None: @@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module): class Dino2Embeddings(torch.nn.Module): - def __init__(self, dim, dtype, device, operations): + def __init__(self, dim, dtype, device, operations, + patch_size: int = 14, image_size: int = 518, + use_mask_token: bool = True, + num_camera_tokens: int = 0): super().__init__() - patch_size = 14 - image_size = 518 self.patch_size = patch_size + self.image_size = image_size self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations) self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device)) self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key. - self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + if use_mask_token: + self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device)) + else: + self.mask_token = None + if num_camera_tokens > 0: + # DA3 stores (ref_token, src_token) pairs that get injected at the + # alt-attn boundary; see ``Dinov2Model._inject_camera_token``. 
+ self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device)) + else: + self.camera_token = None def interpolate_pos_encoding(self, x, h_pixels, w_pixels): pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32) @@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module): patch_pos = pos_embed[:, 1:] N = patch_pos.shape[1] M = int(N ** 0.5) + assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})" h0 = h_pixels // self.patch_size w0 = w_pixels // self.patch_size - scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0). + # scale_factor is (height_scale, width_scale) -- height MUST come first; + # swapping these only happens to work for square inputs and breaks + # non-square paths like DA3-Small / DA3-Base multi-view. 
+ scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2) patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False) + assert (h0, w0) == patch_pos.shape[-2:], ( + f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match " + f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); " + f"check scale_factor axis order and +0.1 rounding workaround" + ) patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2) return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype) @@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module): heads = config_dict["num_attention_heads"] layer_norm_eps = config_dict["layer_norm_eps"] use_swiglu_ffn = config_dict["use_swiglu_ffn"] + patch_size = config_dict.get("patch_size", 14) + image_size = config_dict.get("image_size", 518) + use_mask_token = config_dict.get("use_mask_token", True) - self.embeddings = Dino2Embeddings(dim, dtype, device, operations) - self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn) + # DA3 extensions (all default to disabled). + self.alt_start = config_dict.get("alt_start", -1) + self.qknorm_start = config_dict.get("qknorm_start", -1) + self.rope_start = config_dict.get("rope_start", -1) + self.cat_token = config_dict.get("cat_token", False) + rope_freq = config_dict.get("rope_freq", 100.0) + + self.embed_dim = dim + self.patch_size = patch_size + self.num_register_tokens = 0 + self.patch_start_idx = 1 + + if self.rope_start != -1 and rope_freq > 0: + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) + self._position_getter = _PositionGetter() + else: + self.rope = None + self._position_getter = None + + # camera_token shape: (1, 2, dim) -> (ref_token, src_token). 
+ num_cam_tokens = 2 if self.alt_start != -1 else 0 + + self.embeddings = Dino2Embeddings( + dim, dtype, device, operations, + patch_size=patch_size, image_size=image_size, + use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens, + ) + self.encoder = Dino2Encoder( + dim, heads, layer_norm_eps, num_layers, dtype, device, operations, + use_swiglu_ffn=use_swiglu_ffn, + qknorm_start=self.qknorm_start, + ) self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device) def forward(self, pixel_values, attention_mask=None, intermediate_output=None): + if self.alt_start != -1: + raise RuntimeError( + "Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not " + "apply DA3 extensions (RoPE, alternating attention, camera-token injection). " + "Use get_intermediate_layers_da3() for Depth Anything 3 models." + ) x = self.embeddings(pixel_values) x, i = self.encoder(x, intermediate_output=intermediate_output) x = self.layernorm(x) @@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module): return x, i, pooled_output, None def get_intermediate_layers(self, pixel_values, indices, apply_norm=True): + """Single-view multi-layer feature extraction.""" x = self.embeddings(pixel_values) optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) n_layers = len(self.encoder.layer) @@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module): if i >= max_idx: break return [cache[i] for i in resolved] + + # ------------------------------------------------------------------ + # Depth Anything 3 forward + # ------------------------------------------------------------------ + def _prepare_rope_positions(self, B, S, H, W, device): + if self.rope is None: + return None, None + ph, pw = H // self.patch_size, W // self.patch_size + pos = self._position_getter(B * S, ph, pw, device=device) + # Shift so the cls/cam token at position 0 is reserved for "no diff". 
+ pos = pos + 1 + cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype) + # Per-view local: real grid positions for patches, 0 for cls token. + pos_local = torch.cat([cls_pos, pos], dim=1) + # Global (across views): same grid positions; cls token still at 0, + # but patches share the same positions in every view. + pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1) + return pos_local, pos_global + + def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor: + # x: (B, S, N, C). Replace token at index 0 with the camera token. + if cam_token is not None: + inj = cam_token + else: + ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype) + ref_token = ct[:, :1].expand(B, -1, -1) + src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1) + inj = torch.cat([ref_token, src_token], dim=1) + x = x.clone() + x[:, :, 0] = inj + return x + + def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None): + """Multi-view multi-layer feature extraction used by Depth Anything 3.""" + if pixel_values.ndim == 4: + pixel_values = pixel_values.unsqueeze(1) + assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \ + f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}" + B, S, _, H, W = pixel_values.shape + + # Patch + cls + (interpolated) pos embed for each view. + x = pixel_values.reshape(B * S, 3, H, W) + x = self.embeddings(x) # (B*S, 1+N, C) + x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C) + + pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device) + # optimized_attention is only used by blocks without QK-norm/RoPE + # (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA. 
+ optimized_attention = optimized_attention_for_device(x.device, False, small_input=True) + + out_set = set(out_layers) + export_set = set(export_feat_layers) if export_feat_layers else set() + outputs: list[torch.Tensor] = [] + aux_outputs: list[torch.Tensor] = [] + local_x = x + b_idx = None + + + for i, blk in enumerate(self.encoder.layer): + apply_rope = self.rope is not None and i >= self.rope_start + block_rope = self.rope if apply_rope else None + l_pos = pos_local if apply_rope else None + g_pos = pos_global if apply_rope else None + + # Reference-view selection threshold: matches the upstream constant + # THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied + # cam_token is provided (camera info already pins the geometry). + if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None): + b_idx = select_reference_view(x, strategy=ref_view_strategy) + x = reorder_by_reference(x, b_idx) + local_x = reorder_by_reference(local_x, b_idx) + + if self.alt_start != -1 and i == self.alt_start: + x = self._inject_camera_token(x, B, S, cam_token) + + if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1): + # Global attention across views: flatten S into the seq dim. + t = x.reshape(B, S * x.shape[-2], x.shape[-1]) + p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + else: + # Per-view local attention. 
+ t = x.reshape(B * S, x.shape[-2], x.shape[-1]) + p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None + t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope) + x = t.reshape(B, S, x.shape[-2], x.shape[-1]) + local_x = x + + if i in out_set: + if self.cat_token: + out_x = torch.cat([local_x, x], dim=-1) + else: + out_x = x + # Restore original view order on the way out so heads see views + # in the user's expected order. + if b_idx is not None and self.alt_start != -1: + out_x = restore_original_order(out_x, b_idx) + outputs.append(out_x) + + if i in export_set: + aux = x + if b_idx is not None and self.alt_start != -1: + aux = restore_original_order(aux, b_idx) + aux_outputs.append(aux) + + # Apply final norm. When cat_token is set, only the right half + # ("global" features) is normalised; the left half is left as-is to + # match the upstream DA3 head signature. + normed: list[torch.Tensor] = [] + cls_tokens: list[torch.Tensor] = [] + for out_x in outputs: + cls_tokens.append(out_x[:, :, 0]) + if out_x.shape[-1] == self.embed_dim: + normed.append(self.layernorm(out_x)) + elif out_x.shape[-1] == self.embed_dim * 2: + left = out_x[..., :self.embed_dim] + right = self.layernorm(out_x[..., self.embed_dim:]) + normed.append(torch.cat([left, right], dim=-1)) + else: + raise ValueError(f"Unexpected token width: {out_x.shape[-1]}") + + # Drop cls/cam token from the patch sequence. + normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed] + + # Final layernorm + drop cls token from auxiliary features too. 
+ aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :] + for o in aux_outputs] + return list(zip(normed, cls_tokens)), aux_normed diff --git a/comfy/ldm/colormap.py b/comfy/ldm/colormap.py new file mode 100644 index 000000000..1f4d88bd9 --- /dev/null +++ b/comfy/ldm/colormap.py @@ -0,0 +1,25 @@ +"""Colormap utilities for depth and geometry visualisation.""" + +from __future__ import annotations + +import torch + + +def turbo(x: torch.Tensor) -> torch.Tensor: + """Anton Mikhailov polynomial approximation of the Turbo colormap. + + Args: + x: Float tensor with values in [0, 1]. + + Returns: + RGB tensor of the same shape as ``x`` with a trailing size-3 dimension. + """ + x = x.clamp(0.0, 1.0) + x2 = x * x + x3 = x2 * x + x4 = x2 * x2 + x5 = x4 * x + r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 + g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 + b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 + return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) diff --git a/comfy/ldm/depth_anything_3/camera.py b/comfy/ldm/depth_anything_3/camera.py new file mode 100644 index 000000000..65a57d66f --- /dev/null +++ b/comfy/ldm/depth_anything_3/camera.py @@ -0,0 +1,177 @@ +"""Camera-token encoder and decoder for Depth Anything 3.""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention_for_device +from .transform import affine_inverse, extri_intri_to_pose_encoding + + +# ----------------------------------------------------------------------- +# Building blocks (mirror depth_anything_3.model.utils.{attention,block}) +# ----------------------------------------------------------------------- + + +class _Mlp(nn.Module): + """Standard 2-layer MLP with GELU. 
Matches upstream ``utils.attention.Mlp``.""" + + def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype) + self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(F.gelu(self.fc1(x))) + + +class _LayerScale(nn.Module): + """Per-channel learnable scaling. Matches upstream LayerScale.""" + + def __init__(self, dim, *, device=None, dtype=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + + def forward(self, x): + return x * self.gamma.to(dtype=x.dtype, device=x.device) + + +class _Attention(nn.Module): + """ Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention; + Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias}).""" + + def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, C) + q, k, v = qkv.unbind(2) # each (B, N, C) + attn_fn = optimized_attention_for_device(x.device, small_input=True) + out = attn_fn(q, k, v, heads=self.num_heads) + return self.proj(out) + + +class _Block(nn.Module): + """Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. 
Layout follows upstream utils.block.Block.""" + + def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations) + self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations) + self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity() + + def forward(self, x): + x = x + self.ls1(self.attn(self.norm1(x))) + x = x + self.ls2(self.mlp(self.norm2(x))) + return x + + +class CameraEnc(nn.Module): + """Encode per-view (extrinsics, intrinsics) into a camera token. + + Maps a 9-D pose-encoding vector through a small MLP up to the backbone's + ``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output + has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start`` + of the DINOv2 backbone in place of the cls token. + + Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly. 
+ """ + + def __init__( + self, + dim_out: int = 1024, + dim_in: int = 9, + trunk_depth: int = 4, + target_dim: int = 9, + num_heads: int = 16, + mlp_ratio: int = 4, + init_values: float = 0.01, + *, + device=None, dtype=None, operations=None, + **_kwargs, + ): + super().__init__() + self.target_dim = target_dim + self.trunk_depth = trunk_depth + self.trunk = nn.Sequential(*[ + _Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio, + init_values=init_values, + device=device, dtype=dtype, operations=operations) + for _ in range(trunk_depth) + ]) + self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype) + self.pose_branch = _Mlp( + in_features=dim_in, + hidden_features=dim_out // 2, + out_features=dim_out, + device=device, dtype=dtype, operations=operations, + ) + + def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor, + image_size_hw) -> torch.Tensor: + """Encode camera parameters into ``(B, S, dim_out)`` tokens.""" + c2ws = affine_inverse(extrinsics) + pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw) + tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype)) + tokens = self.token_norm(tokens) + tokens = self.trunk(tokens) + tokens = self.trunk_norm(tokens) + return tokens + + +class CameraDec(nn.Module): + """Decode the final cam token into a 9-D pose encoding. + + Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is + always predicted by the network; the quaternion and FoV can either be + predicted or supplied via ``camera_encoding`` (used at training time + when GT cameras are available -- not exercised at inference here). + + Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly. 
+ """ + + def __init__(self, dim_in: int = 1536, + *, device=None, dtype=None, operations=None, **_kwargs): + super().__init__() + d = dim_in + self.backbone = nn.Sequential( + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + operations.Linear(d, d, device=device, dtype=dtype), + nn.ReLU(), + ) + self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype) + self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype) + self.fc_fov = nn.Sequential( + operations.Linear(d, 2, device=device, dtype=dtype), + nn.ReLU(), + ) + + def forward(self, feat: torch.Tensor, + camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor: + """Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc.""" + B, N = feat.shape[:2] + feat = feat.reshape(B * N, -1) + feat = self.backbone(feat) + out_t = self.fc_t(feat.float()).reshape(B, N, 3) + if camera_encoding is None: + out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4) + out_fov = self.fc_fov(feat.float()).reshape(B, N, 2) + else: + out_qvec = camera_encoding[..., 3:7] + out_fov = camera_encoding[..., -2:] + return torch.cat([out_t, out_qvec, out_fov], dim=-1) diff --git a/comfy/ldm/depth_anything_3/dpt.py b/comfy/ldm/depth_anything_3/dpt.py new file mode 100644 index 000000000..fb940873b --- /dev/null +++ b/comfy/ldm/depth_anything_3/dpt.py @@ -0,0 +1,489 @@ +"""DPT / DualDPT heads for Depth Anything 3.""" + +from __future__ import annotations + +from typing import List, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Permute(nn.Module): + def __init__(self, dims: Tuple[int, ...]): + super().__init__() + self.dims = dims + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.permute(*self.dims) + + +def _custom_interpolate( + x: torch.Tensor, + size: Optional[Tuple[int, int]] = None, + scale_factor: Optional[float] = None, + mode: str = "bilinear", + align_corners: bool = True, +) -> torch.Tensor: + if size is None: 
+ assert scale_factor is not None + size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor)) + INT_MAX = 1610612736 + total = size[0] * size[1] * x.shape[0] * x.shape[1] + if total > INT_MAX: + chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0) + outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks] + return torch.cat(outs, dim=0).contiguous() + return F.interpolate(x, size=size, mode=mode, align_corners=align_corners) + + +def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor: + """Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span).""" + diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5 + span_x = aspect_ratio / diag_factor + span_y = 1.0 / diag_factor + left_x = -span_x * (width - 1) / width + right_x = span_x * (width - 1) / width + top_y = -span_y * (height - 1) / height + bottom_y = span_y * (height - 1) / height + x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device) + y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device) + uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy") + return torch.stack((uu, vv), dim=-1) # (H, W, 2) + + +def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor: + omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) + omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0)) + pos = pos.reshape(-1) + out = torch.einsum("m,d->md", pos, omega) + return torch.cat([out.sin(), out.cos()], dim=1).float() + + +def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor: + H, W, _ = pos_grid.shape + pos_flat = pos_grid.reshape(-1, 2) + emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) + emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) + emb = torch.cat([emb_x, emb_y], dim=-1) + return 
emb.view(H, W, embed_dim) + + +def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor: + """Stateless UV positional embedding added to a feature map (B, C, h, w).""" + pw, ph = x.shape[-1], x.shape[-2] + pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device) + pe = _position_grid_to_embed(pe, x.shape[1]) * ratio + pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype) + return x + pe + + +def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor: + act = (activation or "linear").lower() + if act == "exp": + return torch.exp(x) + if act == "expp1": + return torch.exp(x) + 1 + if act == "expm1": + return torch.expm1(x) + if act == "relu": + return torch.relu(x) + if act == "sigmoid": + return torch.sigmoid(x) + if act == "softplus": + return F.softplus(x) + if act == "tanh": + return torch.tanh(x) + return x + + +# ----------------------------------------------------------------------------- +# Fusion building blocks +# ----------------------------------------------------------------------------- + + +class ResidualConvUnit(nn.Module): + def __init__(self, features: int, device=None, dtype=None, operations=None): + super().__init__() + self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype) + self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype) + self.activation = nn.ReLU(inplace=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.activation(x) + out = self.conv1(out) + out = self.activation(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None): + super().__init__() + self.align_corners = align_corners + self.has_residual = has_residual + if has_residual: + self.resConfUnit1 = 
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
    """Create the ``scratch`` container holding the four ``layerK_rn`` convolutions.

    Each ``layerK_rn`` (K = 1..4) is a bias-free 3x3, stride-1, padding-1
    convolution mapping ``in_shape[K - 1]`` channels to ``out_shape`` channels.

    Args:
        in_shape: Input channel count for each of the four levels.
        out_shape: Common output channel count.
        device: Forwarded to every ``operations.Conv2d``.
        dtype: Forwarded to every ``operations.Conv2d``.
        operations: Namespace providing the ``Conv2d`` class to instantiate.
    """
    container = _Scratch()
    for level in range(4):
        conv = operations.Conv2d(
            in_shape[level], out_shape,
            kernel_size=3, stride=1, padding=1, bias=False,
            device=device, dtype=dtype,
        )
        # Attribute names (layer1_rn .. layer4_rn) are part of the parameter key layout.
        setattr(container, f"layer{level + 1}_rn", conv)
    return container
+class DPT(nn.Module): + """Single-head DPT used by DA3Mono-Large and DA3Metric-Large.""" + + def __init__( + self, + dim_in: int, + patch_size: int = 14, + output_dim: int = 1, + activation: str = "exp", + conf_activation: str = "expp1", + features: int = 256, + out_channels: Sequence[int] = (256, 512, 1024, 1024), + pos_embed: bool = False, + down_ratio: int = 1, + head_name: str = "depth", + use_sky_head: bool = True, + sky_name: str = "sky", + sky_activation: str = "relu", + norm_type: str = "idt", + device=None, dtype=None, operations=None, + ): + super().__init__() + self.patch_size = patch_size + self.activation = activation + self.conf_activation = conf_activation + self.pos_embed = pos_embed + self.down_ratio = down_ratio + self.head_main = head_name + self.sky_name = sky_name + self.out_dim = output_dim + self.has_conf = output_dim > 1 + self.use_sky_head = use_sky_head + self.sky_activation = sky_activation + self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3) + + if norm_type == "layer": + self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype) + else: + self.norm = nn.Identity() + + out_channels = list(out_channels) + self.projects = nn.ModuleList([ + operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype) + for oc in out_channels + ]) + self.resize_layers = nn.ModuleList([ + operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype), + operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype), + nn.Identity(), + operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + ]) + + self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + 
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations) + self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations) + + head_features_1 = features + head_features_2 = 32 + self.scratch.output_conv1 = operations.Conv2d( + head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1, + device=device, dtype=dtype, + ) + self.scratch.output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + + if self.use_sky_head: + self.scratch.sky_output_conv2 = nn.Sequential( + operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype), + nn.ReLU(inplace=False), + operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype), + ) + + def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict: + # feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C) + B, S, N, C = feats[0][0].shape + feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats] + + ph, pw = H // self.patch_size, W // self.patch_size + resized = [] + for stage_idx, take_idx in enumerate(self.intermediate_layer_idx): + x = feats_flat[take_idx][:, patch_start_idx:] + x = self.norm(x) + x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw) + x = self.projects[stage_idx](x) + if self.pos_embed: + x = _add_pos_embed(x, W, H) + x = self.resize_layers[stage_idx](x) + resized.append(x) + + l1_rn = self.scratch.layer1_rn(resized[0]) + l2_rn = self.scratch.layer2_rn(resized[1]) + l3_rn = 
class DualDPT(nn.Module):
    """Two-head DPT used by DA3-Small / DA3-Base.

    A main head predicts depth plus a confidence channel. An auxiliary "ray"
    head, with its own copy of the fusion chain, predicts a 6-channel ray map
    plus confidence; it only runs while ``enable_aux`` is True.

    Args:
        dim_in: Channel width of the incoming backbone tokens.
        patch_size: Backbone patch size; ``H`` and ``W`` are divided by it to
            recover the token grid.
        output_dim: Channels produced by the main head. ``forward`` treats the
            last channel as confidence and the rest as the prediction.
        activation: Activation name applied to the main prediction.
        conf_activation: Activation name applied to both confidence maps.
        features: Channel width of the fusion pyramid.
        out_channels: Per-level channel widths after the 1x1 projections.
        pos_embed: Whether UV positional embeddings are added.
        down_ratio: Divisor applied to the output resolution.
        aux_pyramid_levels: Number of aux pre-head / projection modules built.
        aux_out1_conv_num: Stored on the instance; the aux pre-head built here
            always has five convolutions regardless of this value.
        head_names: Output dict key prefixes for the (main, aux) heads.
        device, dtype, operations: ComfyUI plumbing forwarded to every layer.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 2,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        aux_pyramid_levels: int = 4,
        aux_out1_conv_num: int = 5,
        head_names: Tuple[str, str] = ("depth", "ray"),
        device=None, dtype=None, operations=None,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio
        self.aux_levels = aux_pyramid_levels
        self.aux_out1_conv_num = aux_out1_conv_num
        self.head_main, self.head_aux = head_names
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
        # Toggle the auxiliary ray branch at runtime. Default off (mono path).
        # DepthAnything3Net flips this on when running multi-view + ray-pose.
        self.enable_aux: bool = False

        self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
        out_channels = list(out_channels)
        self.projects = nn.ModuleList([
            operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
            for oc in out_channels
        ])
        # Per-level resampling: x4 up, x2 up, identity, x2 down.
        self.resize_layers = nn.ModuleList([
            operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
            operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
            nn.Identity(),
            operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
        ])

        self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
        # Main fusion chain
        self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
        # Auxiliary fusion chain (separate copies)
        self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)

        head_features_1 = features
        head_features_2 = 32

        # Main head neck + final projection
        self.scratch.output_conv1 = operations.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
            device=device, dtype=dtype,
        )
        self.scratch.output_conv2 = nn.Sequential(
            operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
            nn.ReLU(inplace=False),
            operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
        )

        # Aux pre-head per level (multi-level pyramid)
        self.scratch.output_conv1_aux = nn.ModuleList([
            self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
            for _ in range(self.aux_levels)
        ])

        # Aux final projection per level (includes LayerNorm permute path).
        # NOTE: ``ln_seq`` is built once and unpacked into every level, so the
        # same Permute/LayerNorm instances (and LayerNorm parameters) are
        # shared by all levels' Sequentials.
        ln_seq = [Permute((0, 2, 3, 1)),
                  operations.LayerNorm(head_features_2, device=device, dtype=dtype),
                  Permute((0, 3, 1, 2))]
        self.scratch.output_conv2_aux = nn.ModuleList([
            nn.Sequential(
                operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
                *ln_seq,
                nn.ReLU(inplace=False),
                # 7 output channels: ray(6) + ray_conf(1), split in ``forward``.
                operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
            )
            for _ in range(self.aux_levels)
        ])

    @staticmethod
    def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
        """Build the aux pre-head: five 3x3 convs alternating ``in_ch // 2`` / ``in_ch`` widths.

        The stack starts at ``in_ch`` channels and ends at ``in_ch // 2``.
        """
        # aux_out1_conv_num=5 in all Apache-2.0 variants.
        return nn.Sequential(
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
            operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
        )

    def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
        """Decode backbone features into depth (and optionally ray) maps.

        Args:
            feats: One entry per tapped backbone layer; ``feats[i][0]`` is read
                as a 4-D ``(B, S, N, C)`` token tensor.
            H: Input image height in pixels.
            W: Input image width in pixels.
            patch_start_idx: Number of leading tokens to drop before the tokens
                are reshaped into the ``(H // patch_size, W // patch_size)`` grid.

        Returns:
            Dict with ``head_main`` and ``f"{head_main}_conf"`` entries, plus
            ``head_aux`` and ``f"{head_aux}_conf"`` when ``enable_aux`` is True.
            Every value has ``(B, S)`` as its two leading dimensions.
        """
        B, S, N, C = feats[0][0].shape
        feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
        # Read the toggle once so the whole pass uses a single consistent value.
        want_aux = self.enable_aux

        ph, pw = H // self.patch_size, W // self.patch_size
        resized = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats_flat[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            # (B*S, N, C) tokens -> (B*S, C, ph, pw) feature map.
            x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = _add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)
            resized.append(x)

        l1_rn = self.scratch.layer1_rn(resized[0])
        l2_rn = self.scratch.layer2_rn(resized[1])
        l3_rn = self.scratch.layer3_rn(resized[2])
        l4_rn = self.scratch.layer4_rn(resized[3])

        # Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
        # before interpolation -- replicate that order here).
        m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
        m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
        m = self.scratch.refinenet1(m, l1_rn)
        m = self.scratch.output_conv1(m)

        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)

        m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
        if self.pos_embed:
            m = _add_pos_embed(m, W, H)
        main_logits = self.scratch.output_conv2(m)
        fmap = main_logits.permute(0, 2, 3, 1)
        # Last channel is confidence; everything before it is the prediction.
        depth_pred = _apply_activation(fmap[..., :-1], self.activation)
        depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)

        outs = {
            self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
            f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
        }

        if want_aux:
            # Auxiliary fusion chain; each stage feeds the next, so only the
            # running tensor needs to be kept.
            aux = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
            aux = self.scratch.refinenet3_aux(aux, l3_rn, size=l2_rn.shape[2:])
            aux = self.scratch.refinenet2_aux(aux, l2_rn, size=l1_rn.shape[2:])
            aux = self.scratch.refinenet1_aux(aux, l1_rn)

            # Auxiliary "ray" head -- only the last (finest) pyramid level is
            # returned, so only that level is pushed through its
            # ``output_conv1_aux`` stack (5 convs ending at ``features // 2``
            # channels). The coarser levels' pre-head outputs were never used,
            # so they are no longer computed. The last level then optionally
            # gets a pos-embed and finally ``output_conv2_aux[-1]``.
            finest = 3  # index of the refinenet1_aux output in the 4-stage pyramid
            last_aux = self.scratch.output_conv1_aux[finest](aux)
            if self.pos_embed:
                last_aux = _add_pos_embed(last_aux, W, H)
            last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
            fmap_last = last_aux_logits.permute(0, 2, 3, 1)
            # Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
            aux_pred = fmap_last[..., :-1]
            aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
            outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
            outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])

        return outs
def _build_backbone_config(
    backbone_name: str,
    *,
    alt_start: int,
    qknorm_start: int,
    rope_start: int,
    cat_token: bool,
) -> dict:
    """Assemble the DINOv2 config dict for a named backbone variant.

    Starts from a copy of the matching ``_BACKBONE_PRESETS`` entry and layers
    the DA3-specific settings on top (later keys win on collision).

    Args:
        backbone_name: Key into ``_BACKBONE_PRESETS``.
        alt_start, qknorm_start, rope_start, cat_token: Copied verbatim into
            the returned config under the same names.

    Raises:
        ValueError: If ``backbone_name`` is not a known preset.
    """
    preset = _BACKBONE_PRESETS.get(backbone_name)
    if preset is None:
        raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
    return {
        **preset,
        "layer_norm_eps": 1e-6,
        "patch_size": 14,
        "image_size": 518,
        # No mask_token in DA3 weights; omit param to avoid load warnings.
        "use_mask_token": False,
        "alt_start": alt_start,
        "qknorm_start": qknorm_start,
        "rope_start": rope_start,
        "cat_token": cat_token,
        "rope_freq": 100.0,
    }
plumbing + device=None, dtype=None, operations=None, + **_ignored, + ): + super().__init__() + head_cls = _HEAD_REGISTRY[head_type.lower()] + self.head_type = head_type.lower() + self.has_sky = (self.head_type == "dpt") and head_use_sky_head + self.has_conf = head_output_dim > 1 + self.out_layers = list(out_layers) + + backbone_cfg = _build_backbone_config( + backbone_name, + alt_start=alt_start, + qknorm_start=qknorm_start, + rope_start=rope_start, + cat_token=cat_token, + ) + self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations) + + head_kwargs = dict( + dim_in=head_dim_in, + patch_size=self.PATCH_SIZE, + output_dim=head_output_dim, + features=head_features, + out_channels=tuple(head_out_channels), + device=device, dtype=dtype, operations=operations, + ) + if self.head_type == "dpt": + head_kwargs.update( + use_sky_head=head_use_sky_head, + pos_embed=(False if head_pos_embed is None else head_pos_embed), + ) + else: # dualdpt + head_kwargs.update( + pos_embed=(True if head_pos_embed is None else head_pos_embed), + ) + self.head = head_cls(**head_kwargs) + + # Built only if checkpoint has weights; cam_enc output dim == embed_dim. 
+ embed_dim = backbone_cfg["hidden_size"] + if has_cam_enc: + self.cam_enc = CameraEnc( + dim_out=cam_dim_out if cam_dim_out is not None else embed_dim, + num_heads=max(1, embed_dim // 64), + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_enc = None + if has_cam_dec: + default_dim = embed_dim * (2 if cat_token else 1) + self.cam_dec = CameraDec( + dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim, + device=device, dtype=dtype, operations=operations, + ) + else: + self.cam_dec = None + + self.dtype = dtype + + def forward( + self, + image: torch.Tensor, + extrinsics: Optional[torch.Tensor] = None, + intrinsics: Optional[torch.Tensor] = None, + *, + use_ray_pose: bool = False, + ref_view_strategy: str = "saddle_balanced", + export_feat_layers: Optional[Sequence[int]] = None, + **_unused, + ) -> Dict[str, torch.Tensor]: + """Run depth and optionally pose prediction.""" + if image.ndim == 4: + image = image.unsqueeze(1) # (B, 1, 3, H, W) + assert image.ndim == 5 and image.shape[2] == 3, \ + f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}" + + B, S, _, H, W = image.shape + assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \ + f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}" + + # Camera-token preparation (multi-view path). + cam_token = None + if extrinsics is not None and intrinsics is not None and self.cam_enc is not None: + cam_token = self.cam_enc(extrinsics, intrinsics, (H, W)) + + # Toggle aux ray output on/off depending on what the caller asked for. + if isinstance(self.head, DualDPT): + self.head.enable_aux = bool(use_ray_pose) + + feats, aux_feats = self.backbone.get_intermediate_layers_da3( + image, self.out_layers, cam_token=cam_token, + ref_view_strategy=ref_view_strategy, + export_feat_layers=export_feat_layers, + ) + head_out = self.head(feats, H=H, W=W, patch_start_idx=0) + + # Pose prediction. 
+ out: Dict[str, torch.Tensor] = {} + if use_ray_pose and "ray" in head_out and "ray_conf" in head_out: + ray = head_out["ray"] + ray_conf = head_out["ray_conf"] + extr_c2w, focal, pp = get_extrinsic_from_camray( + ray, ray_conf, ray.shape[-3], ray.shape[-2], + ) + # Match the upstream output: w2c, drop the homogeneous row. + extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :] + # Build pixel-space intrinsics from the normalised focal/pp output. + intr = torch.eye(3, device=ray.device, dtype=ray.dtype) + intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone() + intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W + intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H + intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5 + intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5 + out["extrinsics"] = extr_w2c + out["intrinsics"] = intr + elif self.cam_dec is not None and S > 1: + # Decode the cam-token of the final out_layer into a pose encoding. + cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec) + pose_enc = self.cam_dec(cam_feat) + c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W)) + # Match the upstream output convention: w2c (world->camera), 3x4. + c2w_4x4 = torch.cat([ + c2w_3x4, + torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype) + .view(1, 1, 1, 4).expand(B, S, 1, 4), + ], dim=-2) + out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :] + out["intrinsics"] = intr + + # Flatten the views axis for per-pixel outputs (depth/conf/sky) so the + # per-image consumer keeps its (B*S, H, W) interface. + for k, v in head_out.items(): + if k in ("ray", "ray_conf"): + # Keep multi-view shape for downstream pose work. 
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
    """Compute ``(target_h, target_w)`` for a single image.

    ``upper_bound_resize`` scales the longest side to ``process_res``;
    ``lower_bound_resize`` scales the shortest side to ``process_res``. Each
    scaled dimension is then rounded to the nearest multiple of the patch
    size, with a floor of one full patch.

    Args:
        orig_h: Source image height in pixels.
        orig_w: Source image width in pixels.
        process_res: Target length of the reference side.
        method: ``"upper_bound_resize"`` or ``"lower_bound_resize"``.

    Returns:
        ``(target_h, target_w)``, both positive multiples of ``PATCH_SIZE``.

    Raises:
        ValueError: If ``method`` is not one of the two supported names.
    """
    if method == "upper_bound_resize":
        reference_side = max(orig_h, orig_w)
    elif method == "lower_bound_resize":
        reference_side = min(orig_h, orig_w)
    else:
        raise ValueError(f"Unsupported process_res_method: {method}")
    scale = process_res / float(reference_side)

    # Floor at one patch, not 1 px: for very elongated inputs the short side
    # can round down to 0, and a 1-px side is not a multiple of the patch size.
    new_w = max(PATCH_SIZE, _round_to_patch(int(round(orig_w * scale))))
    new_h = max(PATCH_SIZE, _round_to_patch(int(round(orig_h * scale))))
    return new_h, new_w
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
    """Overwrite sky pixels with a high quantile of the non-sky depth.

    Pixels whose ``sky`` value is at or above ``threshold`` are set to the
    ``quantile``-th quantile (0.99 by default) of the remaining depth values.
    If either group has 10 or fewer pixels, an unmodified copy is returned.

    Args:
        depth: Depth tensor; ``sky`` must index it as a boolean mask.
        sky: Sky prediction, compared against ``threshold``.
        threshold: Sky cut-off passed to ``compute_non_sky_mask``.
        quantile: Quantile in ``[0, 1]`` of non-sky depth used as fill value.

    Returns:
        A new tensor with the same dtype as ``depth``; the input is not mutated.
    """
    non_sky = compute_non_sky_mask(sky, threshold=threshold)
    sky_mask = ~non_sky
    if non_sky.sum() <= 10 or sky_mask.sum() <= 10:
        return depth.clone()

    non_sky_depth = depth[non_sky]
    if non_sky_depth.numel() > 100_000:
        # Estimate the quantile from a random subset (drawn with replacement)
        # so its cost stays bounded on large inputs.
        idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
        sampled = non_sky_depth[idx]
    else:
        sampled = non_sky_depth

    # torch.quantile only accepts float32/float64 input; upcast reduced
    # precision (fp16/bf16) samples and cast the result back afterwards.
    if sampled.dtype not in (torch.float32, torch.float64):
        sampled = sampled.float()
    max_depth = torch.quantile(sampled, quantile).to(depth.dtype)

    out = depth.clone()
    out[sky_mask] = max_depth
    return out
valid.numel(), (100_000,), device=valid.device) + sample = valid[idx] + else: + sample = valid + + lo = torch.quantile(sample, low_quantile) + hi = torch.quantile(sample, high_quantile) + rng = (hi - lo).clamp(min=1e-6) + norm = ((depth - lo) / rng).clamp(0.0, 1.0) + # Nearer pixels are brighter (1.0) + norm = 1.0 - norm + if sky is not None: + # Sky pixels become black (far / unknown) + sky_mask = ~compute_non_sky_mask(sky) + norm = torch.where(sky_mask, torch.zeros_like(norm), norm) + return norm + + +def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor: + """Simple per-frame min/max normalization with near=1.0 convention.""" + lo = depth.amin(dim=(-2, -1), keepdim=True) + hi = depth.amax(dim=(-2, -1), keepdim=True) + rng = (hi - lo).clamp(min=1e-6) + return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0) diff --git a/comfy/ldm/depth_anything_3/ray_pose.py b/comfy/ldm/depth_anything_3/ray_pose.py new file mode 100644 index 000000000..90890f1da --- /dev/null +++ b/comfy/ldm/depth_anything_3/ray_pose.py @@ -0,0 +1,272 @@ +"""Ray-to-pose conversion for the multi-view path of Depth Anything 3.""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + + +# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops. + + +def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Decompose A = Q @ L with Q orthogonal and L lower-triangular. + Implemented in terms of QR by reversing the columns/rows; the standard + trick from the upstream reference. Inputs A are (3, 3).""" + P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype) + A_tilde = A @ P + # CUDA QR is not implemented for fp16/bf16; upcast just for this call. 
+ Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float()) + Q_tilde = Q_tilde.to(A.dtype) + R_tilde = R_tilde.to(A.dtype) + Q = Q_tilde @ P + L = P @ R_tilde @ P + d = torch.diag(L) + sign = torch.sign(d) + Q = Q * sign[None, :] # scale columns of Q + L = L * sign[:, None] # scale rows of L + return Q, L + + +def _homogenize_points(points: torch.Tensor) -> torch.Tensor: + return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + + +# ----------------------------------------------------------------------------- +# Weighted-LSQ + RANSAC homography (batched) +# ----------------------------------------------------------------------------- + + +def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor: + """Solve a single H with weighted least-squares (DLT).""" + N = src_pts.shape[0] + if N < 4: + raise ValueError("At least 4 points are required to compute a homography.") + w = confident_weight.sqrt().unsqueeze(1) # (N, 1) + x = src_pts[:, 0:1] + y = src_pts[:, 1:2] + u = dst_pts[:, 0:1] + v = dst_pts[:, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1) + A = torch.cat([A1, A2], dim=0) # (2N, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[-1].reshape(3, 3) + return H / H[-1, -1] + + +def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor: + """Batched DLT solver. 
Inputs (B, K, 2) / (B, K); output (B, 3, 3).""" + B, K, _ = src_pts_batch.shape + w = confident_weight_batch.sqrt().unsqueeze(2) + x = src_pts_batch[:, :, 0:1] + y = src_pts_batch[:, :, 1:2] + u = dst_pts_batch[:, :, 0:1] + v = dst_pts_batch[:, :, 1:2] + zeros = torch.zeros_like(x) + A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2) + A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2) + A = torch.cat([A1, A2], dim=1) # (B, 2K, 9) + # CUDA SVD is not implemented for fp16/bf16; upcast just for this call. + _, _, Vh = torch.linalg.svd(A.float()) + Vh = Vh.to(A.dtype) + H = Vh[:, -1].reshape(B, 3, 3) + return H / H[:, 2:3, 2:3] + + +def _ransac_find_homography_weighted_batched( + src_pts: torch.Tensor, # (B, N, 2) + dst_pts: torch.Tensor, # (B, N, 2) + confident_weight: torch.Tensor, # (B, N) + n_sample: int, + n_iter: int = 100, + reproj_threshold: float = 3.0, + num_sample_for_ransac: int = 8, + max_inlier_num: int = 10000, + rand_sample_iters_idx: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Batched weighted-RANSAC homography estimator. 
Returns (B, 3, 3) homography matrices.""" + B, N, _ = src_pts.shape + assert N >= 4 + device = src_pts.device + + sorted_idx = torch.argsort(confident_weight, descending=True, dim=1) + candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample) + + if rand_sample_iters_idx is None: + rand_sample_iters_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] + for _ in range(n_iter)], + dim=0, + ) + + rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k) + b_idx = ( + torch.arange(B, device=device) + .view(B, 1, 1) + .expand(B, n_iter, num_sample_for_ransac) + ) + src_b = src_pts[b_idx, rand_idx] + dst_b = dst_pts[b_idx, rand_idx] + w_b = confident_weight[b_idx, rand_idx] + + cB, cN = src_b.shape[:2] + H_batch = _find_homography_weighted_lsq_batched( + src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1), + ).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3) + + src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2) + proj = torch.bmm( + src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3), + H_batch.reshape(-1, 3, 3).transpose(1, 2), + ) # (B*n_iter, N, 3) + proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2) + err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N) + inlier_mask = err < reproj_threshold + score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2) + best_idx = torch.argmax(score, dim=1) + best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx] + + # Refit with the inlier set (per-batch, since the inlier counts vary). + H_inlier_list = [] + for b in range(B): + mask = best_inlier_mask[b] + in_src = src_pts[b][mask] + in_dst = dst_pts[b][mask] + in_w = confident_weight[b][mask] + if in_src.shape[0] < 4: + # Fall back to identity when RANSAC fails to find enough inliers. 
+ H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype)) + continue + sorted_w = torch.argsort(in_w, descending=True) + if len(sorted_w) > max_inlier_num: + keep = max(int(len(sorted_w) * 0.95), max_inlier_num) + sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]] + H_inlier_list.append( + _find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w]) + ) + return torch.stack(H_inlier_list, dim=0) + + +# ----------------------------------------------------------------------------- +# Camera-ray utilities +# ----------------------------------------------------------------------------- + + +def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor: + """Camera-space unit rays for an identity intrinsic on a 2x2 image plane.""" + dx = 1.0 / num_x + dy = 1.0 / num_y + # Centered camera-space coords directly (skip the K^-1 step since it's + # just a translation by -1 on x and y when K is identity-with-center=1). 
+ y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype) + x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype) + yy, xx = torch.meshgrid(y, x, indexing="ij") + grid = torch.stack((xx, yy), dim=-1) # (h, w, 2) + grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2) + return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1) + + +def _camray_to_caminfo( + camray: torch.Tensor, # (B, S, h, w, 6) + confidence: Optional[torch.Tensor] = None, # (B, S, h, w) + reproj_threshold: float = 0.2, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert per-pixel camera rays to per-view (R, T, focal, principal).""" + if confidence is None: + confidence = torch.ones_like(camray[..., 0]) + B, S, h, w, _ = camray.shape + device = camray.device + dtype = camray.dtype + + rays_target = camray[..., :3] # (B, S, h, w, 3) + rays_origin = _unproject_identity(h, w, B, S, device, dtype) + + # Flatten (B*S, h*w, *) for the RANSAC routine. + rays_target = rays_target.flatten(0, 1).flatten(1, 2) + rays_origin = rays_origin.flatten(0, 1).flatten(1, 2) + weights = confidence.flatten(0, 1).flatten(1, 2).clone() + + # Project to 2D in homogeneous form (the upstream calls this "perspective division"). 
+ z_thresh = 1e-4 + mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh) + weights = torch.where(mask, weights, torch.zeros_like(weights)) + src = rays_origin.clone() + dst = rays_target.clone() + src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0]) + src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1]) + dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0]) + dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1]) + src = src[..., :2] + dst = dst[..., :2] + + N = src.shape[1] + n_iter = 100 + sample_ratio = 0.3 + num_sample_for_ransac = 8 + n_sample = max(num_sample_for_ransac, int(N * sample_ratio)) + rand_idx = torch.stack( + [torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)], + dim=0, + ) + + # Chunk along the view axis to keep peak memory predictable. + chunk = 2 + A_list = [] + for i in range(0, src.shape[0], chunk): + A = _ransac_find_homography_weighted_batched( + src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk], + n_sample=n_sample, n_iter=n_iter, + num_sample_for_ransac=num_sample_for_ransac, + reproj_threshold=reproj_threshold, + rand_sample_iters_idx=rand_idx, + max_inlier_num=8000, + ) + # Flip sign on dets that come out < 0 (so that the QL produces a + # right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so + # do the comparison in fp32. 
+ flip = torch.linalg.det(A.float()) < 0 + A = torch.where(flip[:, None, None], -A, A) + A_list.append(A) + A = torch.cat(A_list, dim=0) # (B*S, 3, 3) + + R_list, f_list, pp_list = [], [], [] + for i in range(A.shape[0]): + R, L = _ql_decomposition(A[i]) + L = L / L[2][2] + f_list.append(torch.stack((L[0][0], L[1][1]))) + pp_list.append(torch.stack((L[2][0], L[2][1]))) + R_list.append(R) + R = torch.stack(R_list).reshape(B, S, 3, 3) + focal = torch.stack(f_list).reshape(B, S, 2) + pp = torch.stack(pp_list).reshape(B, S, 2) + + # Translation: confidence-weighted average of camray direction(s). + cf = confidence.flatten(0, 1).flatten(1, 2) + T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1) + T = T / cf.sum(dim=-1, keepdim=True) + T = T.reshape(B, S, 3) + + # Match upstream output convention: focal -> 1/focal, pp + 1. + return R, T, 1.0 / focal, pp + 1.0 + + +def get_extrinsic_from_camray( + camray: torch.Tensor, # (B, S, h, w, 6) + conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w) + patch_size_y: int, + patch_size_x: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Wrap a 4x4 extrinsic + per-view focal + principal-point output.""" + if conf.ndim == 5 and conf.shape[-1] == 1: + conf = conf.squeeze(-1) + R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf) + extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4) + homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device) + homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4) + extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4) + return extr, focal, pp diff --git a/comfy/ldm/depth_anything_3/reference_view_selector.py b/comfy/ldm/depth_anything_3/reference_view_selector.py new file mode 100644 index 000000000..90f00be92 --- /dev/null +++ b/comfy/ldm/depth_anything_3/reference_view_selector.py @@ -0,0 +1,87 @@ +"""Reference-view selection for the multi-view path of Depth Anything 3.""" + +from __future__ import 
annotations + +from typing import Literal + +import torch + + +RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"] + + +# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``. +# Reference selection only runs when there are at least this many views. +THRESH_FOR_REF_SELECTION: int = 3 + + +def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor: + """Pick a reference view index per batch element.""" + B, S, _, _ = x.shape + if S <= 1: + return torch.zeros(B, dtype=torch.long, device=x.device) + if strategy == "first": + return torch.zeros(B, dtype=torch.long, device=x.device) + if strategy == "middle": + return torch.full((B,), S // 2, dtype=torch.long, device=x.device) + + # Feature-based strategies: normalised cls/cam token per view. + img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C) + + if strategy == "saddle_balanced": + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S) + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S) + feat_norm = x[:, :, 0].norm(dim=-1) # (B,S) + feat_var = img_class_feat.var(dim=-1) # (B,S) + + def _normalize(metric): + mn = metric.min(dim=1, keepdim=True).values + mx = metric.max(dim=1, keepdim=True).values + return (metric - mn) / (mx - mn + 1e-8) + + sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var) + balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs() + return balance.argmin(dim=1) + + if strategy == "saddle_sim_range": + sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) + sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0) + sim_max = sim_no_diag.max(dim=-1).values + sim_min = sim_no_diag.min(dim=-1).values + return (sim_max - sim_min).argmax(dim=1) + + raise ValueError( + f"Unknown reference view selection strategy: 
{strategy!r}. " + f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'" + ) + + +def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor: + """Reorder x so the reference view is at position 0 in axis S.""" + B, S = x.shape[0], x.shape[1] + if S <= 1: + return x + positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) + b_idx_exp = b_idx.unsqueeze(1) + reorder = torch.where( + (positions > 0) & (positions <= b_idx_exp), + positions - 1, + positions, + ) + reorder[:, 0] = b_idx + batch = torch.arange(B, device=x.device).unsqueeze(1) + return x[batch, reorder] + + +def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor: + """Inverse of reorder_by_reference.""" + B, S = x.shape[0], x.shape[1] + if S <= 1: + return x + target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1) + b_idx_exp = b_idx.unsqueeze(1) + restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions) + restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp)) + batch = torch.arange(B, device=x.device).unsqueeze(1) + return x[batch, restore] diff --git a/comfy/ldm/depth_anything_3/transform.py b/comfy/ldm/depth_anything_3/transform.py new file mode 100644 index 000000000..b735d7bec --- /dev/null +++ b/comfy/ldm/depth_anything_3/transform.py @@ -0,0 +1,160 @@ +"""Geometry / camera transform helpers for Depth Anything 3.""" + +from __future__ import annotations + +from typing import Tuple + +import torch +import torch.nn.functional as F + + +# ----------------------------------------------------------------------------- +# Affine 4x4 helpers +# ----------------------------------------------------------------------------- + + +def as_homogeneous(ext: torch.Tensor) -> torch.Tensor: + """Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. 
No-op when the input is already ``(...,4,4)``.""" + if ext.shape[-2:] == (4, 4): + return ext + if ext.shape[-2:] == (3, 4): + ones = torch.zeros_like(ext[..., :1, :4]) + ones[..., 0, 3] = 1.0 + return torch.cat([ext, ones], dim=-2) + raise ValueError(f"Invalid affine shape: {ext.shape}") + + +def affine_inverse(A: torch.Tensor) -> torch.Tensor: + """Inverse of an affine matrix ``[R|T; 0 0 0 1]``.""" + R = A[..., :3, :3] + T = A[..., :3, 3:] + P = A[..., 3:, :] + return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2) + + +# ----------------------------------------------------------------------------- +# Quaternion <-> rotation matrix (xyzw / scalar-last) +# ----------------------------------------------------------------------------- + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """sqrt(max(0, x)) with a zero subgradient where x == 0.""" + ret = torch.zeros_like(x) + positive_mask = x > 0 + if torch.is_grad_enabled(): + ret[positive_mask] = torch.sqrt(x[positive_mask]) + else: + ret = torch.where(positive_mask, torch.sqrt(x), ret) + return ret + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """Force the real part of a unit quaternion (xyzw) to be non-negative.""" + return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions) + + +def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor: + """Convert quaternions (xyzw) to (...,3,3) rotation matrices.""" + i, j, k, r = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor: + """Convert (...,3,3) rotation matrices to 
quaternions (xyzw).""" + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape( + batch_dim + (4,) + ) + # Reorder rijk -> xyzw (i.e. ijkr). + out = out[..., [1, 2, 3, 0]] + return standardize_quaternion(out) + + +# ----------------------------------------------------------------------------- +# Pose-encoding <-> extrinsics + intrinsics +# ----------------------------------------------------------------------------- + + +def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor: + """Pack (extr, intr, image_size) into the 9-D pose-encoding vector. + extrinsics: camera-to-world (c2w) (B,S,4,4) matrices, + intrinsics: pixel-space (B,S,3,3) matrices, + image_size_hw: is a (H, W) pair. 
+ """ + R = extrinsics[..., :3, :3] + T = extrinsics[..., :3, 3] + quat = mat_to_quat(R) + H, W = image_size_hw + fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1]) + fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0]) + return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float() + + +def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]: + """Inverse of extri_intri_to_pose_encoding.""" + T = pose_encoding[..., :3] + quat = pose_encoding[..., 3:7] + fov_h = pose_encoding[..., 7] + fov_w = pose_encoding[..., 8] + # Normalize to unit quaternion. CameraDec outputs raw values; a near-zero + # quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics. + quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6) + R = quat_to_mat(quat) + extrinsics = torch.cat([R, T[..., None]], dim=-1) + H, W = image_size_hw + fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6) + fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6) + intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype) + intrinsics[..., 0, 0] = fx + intrinsics[..., 1, 1] = fy + intrinsics[..., 0, 2] = W / 2 + intrinsics[..., 1, 2] = H / 2 + intrinsics[..., 2, 2] = 1.0 + return extrinsics, intrinsics diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py index 3b02a243a..4ea5b8aaf 100644 --- a/comfy/ldm/ideogram4/model.py +++ b/comfy/ldm/ideogram4/model.py @@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module): self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) - def forward(self, x): + def forward(self, x, dtype): x = x.to(torch.float32) scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min) emb = _sinusoidal_embedding(scaled, self.dim) - emb = emb.to(self.mlp_in.weight.dtype) 
+ emb = emb.to(dtype) emb = F.silu(self.mlp_in(emb)) return self.mlp_out(emb) @@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module): x = x * output_image_mask h = self.input_proj(x) * output_image_mask - t_cond = self.t_embedding(t) + t_cond = self.t_embedding(t, dtype=x.dtype) if t.dim() == 1: t_cond = t_cond.unsqueeze(1) adaln_input = F.silu(self.adaln_proj(t_cond)) @@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module): llm = self.llm_cond_proj(llm) * text_mask h[:, :L_text] = h[:, :L_text] + llm - h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long)) + h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype) # Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch). freqs_cis = precompute_freqs_cis( @@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer): def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options): B = x_chunk.shape[0] device = x_chunk.device - img_tokens = self._img_to_tokens(x_chunk).to(self.dtype) + img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] L_text = context_chunk.shape[1] L = L_text + L_img @@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer): def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options): B = x_chunk.shape[0] device = x_chunk.device - img_tokens = self._img_to_tokens(x_chunk).to(self.dtype) + img_tokens = self._img_to_tokens(x_chunk) L_img = img_tokens.shape[1] position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3) diff --git a/comfy/ldm/omnigen/omnigen2.py b/comfy/ldm/omnigen/omnigen2.py index 82edc92da..e9ca5229d 100644 --- a/comfy/ldm/omnigen/omnigen2.py +++ b/comfy/ldm/omnigen/omnigen2.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from comfy.ldm.lightricks.model import Timesteps from comfy.ldm.flux.layers 
import EmbedND +from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.modules.attention import optimized_attention_masked import comfy.model_management import comfy.ldm.common_dit @@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis): if x.shape[1] == 0: return x - t_ = x.reshape(*x.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x.shape).to(dtype=x.dtype) + return apply_rope1(x, freqs_cis) def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 3462d8108..e49886dd9 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -51,6 +51,18 @@ class FeedForward(nn.Module): return hidden_states +# Adding this back because Nunchaku custom nodes rely on it, see comment here: +# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161 +# TODO: Eventually remove this once we natively support SVDQuants +def apply_rotary_emb(x, freqs_cis): + if x.shape[1] == 0: + return x + + t_ = x.reshape(*x.shape[:-1], -1, 1, 2) + t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] + return t_out.reshape(*x.shape) + + class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 70dfe7b16..282408891 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -8,7 +8,7 @@ from einops import rearrange from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.layers import EmbedND -from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.math import apply_rope1, rope import comfy.ldm.common_dit import comfy.model_management import comfy.patcher_extension @@ -570,6 +570,14 @@ class WanModel(torch.nn.Module): full_ref =
self.ref_conv(full_ref).flatten(2).transpose(1, 2) x = torch.concat((full_ref, x), dim=1) + # In-context reference (Bernini) + context_latents = kwargs.get("context_latents", None) + main_len = x.shape[1] + if context_latents is not None: + for lat in context_latents: + cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2) + x = torch.cat([x, cl], dim=1) + # context context = self.text_embedding(context) @@ -599,6 +607,9 @@ class WanModel(torch.nn.Module): # head x = self.head(x, e) + if context_latents is not None: + x = x[:, :main_len] + if full_ref is not None: x = x[:, full_ref.shape[1]:] @@ -606,7 +617,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -638,6 +649,13 @@ class WanModel(torch.nn.Module): img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) + + # In-context reference: a non-zero source_id composes an extra rotation into the spatial rope + if source_id: + d = self.dim // self.num_heads + pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32) + id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype) + freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot) return freqs def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): @@ -661,6 +679,15 @@ class WanModel(torch.nn.Module): t_len += 1 freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, 
transformer_options=transformer_options) + + # In-context reference: one rope block per stream, each with its own source_id (1, 2, ...) to distinguish from the target (id 0). + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents] + for i, lat in enumerate(context_latents): + freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1) + kwargs = {**kwargs, "context_latents": context_latents} + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): @@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel): self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) - def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs): if reference_latent is not None: x = torch.cat((reference_latent, x), dim=2) # embeddings x = self.patch_embedding(x.float()).to(x.dtype) + if ref_mask_latents is not None: # SCAIL-2 additive mask stream + x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype) grid_sizes = x.shape[2:] transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) @@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel): scail_pose_seq_len = 0 if pose_latents is not None: scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + if sam_latents is not None: # SCAIL-2 additive mask stream +
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype) scail_x = scail_x.flatten(2).transpose(1, 2) scail_pose_seq_len = scail_x.shape[1] x = torch.cat([x, scail_x], dim=1) @@ -1695,7 +1726,36 @@ class SCAILWanModel(WanModel): return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + # ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode, + # which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset. + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}): + if ref_mask_flag is not None and not bool(ref_mask_flag): + REF_ROPE_H = 120.0 + POSE_ROPE_W = 120.0 + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + main_t_patches = t - ref_t_patches + + parts = [] + if ref_t_patches > 0: + ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}} + parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf)) + if main_t_patches > 0: + parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options)) + + if pose_latents is not None: + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + h_scale = h / H_pose + w_scale = w / W_pose + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, 
dtype=dtype, transformer_options=pose_tf)) + + return torch.cat(parts, dim=1) + main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) if pose_latents is None: @@ -1719,12 +1779,16 @@ class SCAILWanModel(WanModel): return torch.cat([main_freqs, pose_freqs], dim=1) - def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs): bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) if pose_latents is not None: pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + if ref_mask_latents is not None: # SCAIL-2 + ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size) + if sam_latents is not None: # SCAIL-2 + sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size) t_len = t if time_dim_concat is not None: @@ -1737,5 +1801,15 @@ class SCAILWanModel(WanModel): reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) t_len += reference_latent.shape[2] - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) - return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] + ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2 + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, 
reference_latent=reference_latent, ref_mask_flag=ref_mask_flag) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w] + + +class SCAIL2WanModel(SCAILWanModel): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) diff --git a/comfy/lora.py b/comfy/lora.py index 4e0ea29e0..2c8d0f0bf 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}): key_lora = k[len("diffusion_model."):-len(".weight")] key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 042804771..ab4a11022 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -65,6 +65,7 @@ import comfy.ldm.ernie.model import comfy.ldm.sam3.detector import comfy.ldm.hidream_o1.model from comfy.ldm.hidream_o1.conditioning import build_extra_conds +import comfy.ldm.depth_anything_3.model import comfy.model_management import comfy.patcher_extension @@ -1518,8 +1519,26 @@ class WAN21(BaseModel): if reference_latents is not None: out['reference_latent'] = 
comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0]) + # In-context reference conditioning (Bernini) + context_latents = kwargs.get("context_latents", None) + if context_latents is not None: + out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents]) + return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + # In-context cond slicing (Bernini) + if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list): + dim = window.dim + out = [] + for lat in cond_value.cond: + if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]: + out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list)) + else: + out.append(lat.to(device)) + return cond_value._copy_with(out) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_CausalAR(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): @@ -1754,6 +1773,97 @@ class WAN21_SCAIL(WAN21): return out +class WAN21_SCAIL2(WAN21_SCAIL): + """SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream.""" + + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents") + self.memory_usage_shape_process = { + "pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + "sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]], + } + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if 
driving_mask_28ch is not None: + out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous()) + + ref_mask_flag = kwargs.get("ref_mask_flag", None) + if ref_mask_flag is not None: + out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag) + + return out + + def extra_conds_shapes(self, **kwargs): + out = super().extra_conds_shapes(**kwargs) + driving_mask_28ch = kwargs.get("driving_mask_28ch", None) + if driving_mask_28ch is not None: + s = driving_mask_28ch.shape + out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]] + ref_mask_28ch = kwargs.get("ref_mask_28ch", None) + if ref_mask_28ch is not None: + s = ref_mask_28ch.shape + out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]] + return out + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key in ("sam_latents", "pose_latents"): + # Return sliced view omitting retain_index_list + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0) + if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + # The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows + full_ref_mask = cond_value.cond + video_frame_count = x_in.shape[2] + if full_ref_mask.shape[2] != video_frame_count + 1: + return None + window_length = len(window.index_list) + + # Account for the causal anchor frame if it exists + anchor_index = getattr(window, "causal_anchor_index", None) + if anchor_index is not None and anchor_index >= 0: + window_length += 1 + + window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device) + return cond_value._copy_with(window_ref_mask) + + return 
super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + + def concat_cond(self, **kwargs): + # The 4 extra channels are the history_mask (1 at clean-anchor frames). + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1] + if extra_channels != 4: + return super().concat_cond(**kwargs) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + return torch.zeros_like(noise)[:, :4] + + device = kwargs["device"] + if mask.shape[1] != 4: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + if mask.shape[1] == 1: + mask = mask.repeat(1, 4, 1, 1, 1) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + return mask + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + # Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image. 
+ return latent_image + + class WAN22_WanDancer(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel) @@ -2227,6 +2337,12 @@ class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class DepthAnything3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net) + class ErnieImage(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 74c838d13..7d0cab308 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail2" elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "scail" elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys: @@ -860,6 +862,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0] return dit_config + # Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py) + if 
'{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "DepthAnything3" + + patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)] + embed_dim = patch_w.shape[0] + depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.') + + # Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg). + backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim) + if backbone_name is None: + return None + dit_config["backbone_name"] = backbone_name + + # Detect DA3 extensions on top of vanilla DINOv2. + has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys + # qk-norm shows up as `attention.q_norm.weight` on enabled blocks. + qknorm_indices = [ + i for i in range(depth) + if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys + ] + qknorm_start = qknorm_indices[0] if qknorm_indices else -1 + + # The DA3 main-series configs always set alt_start == qknorm_start == rope_start. + # cat_token=True is implied by the presence of camera_token. + if has_camera_token: + dit_config["alt_start"] = qknorm_start + dit_config["rope_start"] = qknorm_start + dit_config["qknorm_start"] = qknorm_start + dit_config["cat_token"] = True + else: + dit_config["alt_start"] = -1 + dit_config["rope_start"] = -1 + dit_config["qknorm_start"] = -1 + dit_config["cat_token"] = False + + # Detect head type and config. 
+ has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys + dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1] + dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0] + dit_config["head_out_channels"] = [ + state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0] + for i in range(4) + ] + if has_aux: + # DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width). + dit_config["head_type"] = "dualdpt" + dit_config["head_output_dim"] = 2 + dit_config["head_use_sky_head"] = False + else: + dit_config["head_type"] = "dpt" + dit_config["head_output_dim"] = state_dict[ + '{}head.scratch.output_conv2.2.weight'.format(key_prefix) + ].shape[0] + dit_config["head_use_sky_head"] = ( + '{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys + ) + + # out_layers: hard-coded per upstream YAML config (depth-aware default). + if depth >= 24: + # vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT). + if has_aux: + dit_config["out_layers"] = [11, 15, 19, 23] + else: + dit_config["out_layers"] = [4, 11, 17, 23] + else: + # vits/vitb: 12 blocks + dit_config["out_layers"] = [5, 7, 9, 11] + + # Camera encoder/decoder presence (multi-view + pose path). 
+ has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys + has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys + dit_config["has_cam_enc"] = has_cam_enc + dit_config["has_cam_dec"] = has_cam_dec + if has_cam_enc: + cam_enc_w = state_dict.get( + '{}cam_enc.pose_branch.fc2.weight'.format(key_prefix) + ) + if cam_enc_w is not None: + dit_config["cam_dim_out"] = cam_enc_w.shape[0] + if has_cam_dec: + cam_dec_w = state_dict.get( + '{}cam_dec.fc_t.weight'.format(key_prefix) + ) + if cam_dec_w is not None: + dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1] + return dit_config + if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image dit_config = {} dit_config["image_model"] = "ernie" diff --git a/comfy/model_management.py b/comfy/model_management.py index dfd58bf1b..b15d08ba1 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -534,8 +534,10 @@ try: except: pass -if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast: - torch.backends.cudnn.benchmark = True + +def set_cudnn_benchmark(): + if torch.cuda.is_available() and torch.backends.cudnn.is_available(): + torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast try: if torch_version_numeric >= (2, 5): @@ -641,6 +643,8 @@ def free_pins(size, evict_active=False): return freed_total def ensure_pin_budget(size, evict_active=False): + if args.high_ram: + return True if args.fast_disk: shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY else: @@ -651,8 +655,7 @@ def ensure_pin_budget(size, evict_active=False): to_free = shortfall + PIN_PRESSURE_HYSTERESIS return free_pins(to_free, evict_active=evict_active) >= shortfall -def ensure_pin_registerable(size, evict_active=True): - shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY +def free_registrations(shortfall, evict_active=True): if MAX_PINNED_MEMORY <= 0: return False if 
shortfall <= 0: @@ -674,6 +677,9 @@ def ensure_pin_registerable(size, evict_active=True): return True return shortfall <= REGISTERABLE_PIN_HYSTERESIS +def ensure_pin_registerable(size, evict_active=True): + return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active) + class LoadedModel: def __init__(self, model: ModelPatcher): self._set_model(model) @@ -956,8 +962,6 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False - reset_cast_buffers() - for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -1494,6 +1498,8 @@ if not args.disable_pinned_memory: PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def pinned_hostbuf_size(size): + if args.high_ram: + return max(0, int(size * 2)) return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) def discard_cuda_async_error(): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b716a69e2..d70b42bf8 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -379,10 +379,11 @@ class ModelPatcher: def get_clone_model_override(self): return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) - def clone(self, disable_dynamic=False, model_override=None): + def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False): class_ = self.__class__ - if self.is_dynamic() and disable_dynamic: - class_ = ModelPatcher + if self.is_dynamic() and disable_dynamic or force_deepcopy: + if self.is_dynamic() and disable_dynamic: + class_ = ModelPatcher if model_override is None: if self.cached_patcher_init is None: raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") diff --git a/comfy/ops.py b/comfy/ops.py index 119177c37..3f088a962 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if pin is not 
None: cast_maybe_lowvram_patch([pin], dest, offload_stream) return - if signature is None: + if signature is None or args.high_ram: comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest) @@ -299,21 +299,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of non_blocking = comfy.model_management.device_supports_non_blocking(device) - if hasattr(s, "_v"): + if hasattr(s, "_v") and comfy.model_management.is_device_cpu(device): #vbar doesn't support CPU weights, but some custom nodes have weird paths #that might switch the layer to the CPU and expect it to work. We have to take #a clone conservatively as we are mmapped and some SFT files are packed misaligned #If you are a custom node author reading this, please move your layer to the GPU #or declare your ModelPatcher as CPU in the first place. - if comfy.model_management.is_device_cpu(device): - materialize_meta_param(s, ["weight", "bias"]) - weight = s.weight.to(dtype=dtype, copy=True) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None - return format_return((weight, bias, (None, None, None)), offloadable) + materialize_meta_param(s, ["weight", "bias"]) + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None + return format_return((weight, bias, (None, None, None)), offloadable) + elif hasattr(s, "_v") and s.weight.device != device: prefetched = hasattr(s, "_prefetch") offload_stream = None offload_device = None diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index ffe12e0dc..cb77c517a 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -89,13 +89,26 @@ def pin_memory(module, 
subset="weights", size=None): not comfy.model_management.ensure_pin_registerable(registerable_size)): return _steal_pin(module, stack, buckets, size, priority) + extended = False try: - hostbuf.extend(size=size) + hostbuf.extend(size=size, register=False) + extended = True + pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + pin.untyped_storage()._comfy_hostbuf = hostbuf + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + comfy.model_management.free_registrations(size) + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + del pin + hostbuf.truncate(offset, do_unregister=False) + return _steal_pin(module, stack, buckets, size, priority) except RuntimeError: + if extended: + hostbuf.truncate(offset, do_unregister=False) return _steal_pin(module, stack, buckets, size, priority) - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] - module._pin.untyped_storage()._comfy_hostbuf = hostbuf + module._pin = pin stack.append((module, offset)) module._pin_registered = True module._pin_stack_index = len(stack) - 1 diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 7cf9c133b..3be935577 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V): out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) return out + +class WAN21_SCAIL2(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail2", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device) + return out + class WAN22_WanDancer(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -2045,6 +2056,23 @@ class RT_DETR_v4(supported_models_base.BASE): return None +class DepthAnything3(supported_models_base.BASE): + 
unet_config = { + "image_model": "DepthAnything3", + } + + # Mono path: no num_heads / num_head_channels needed. + unet_extra_config = {} + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.DepthAnything3(self, device=device) + + def clip_target(self, state_dict={}): + return None + + class ErnieImage(supported_models_base.BASE): unet_config = { "image_model": "ernie", @@ -2259,6 +2287,7 @@ models = [ WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, + WAN21_SCAIL2, WAN22_WanDancer, Hunyuan3Dv2mini, Hunyuan3Dv2, @@ -2286,4 +2315,5 @@ models = [ CogVideoX_I2V, CogVideoX_T2V, SVD_img2vid, + DepthAnything3, ] diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py index 55e655d67..84243772d 100644 --- a/comfy/text_encoders/ideogram4.py +++ b/comfy/text_encoders/ideogram4.py @@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer): self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): - if llama_template is None: + if text.startswith('<|im_start|>'): + llama_text = text + elif llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 8fff52c16..e2e99521f 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -27,10 +27,13 @@ class VideoInput(ABC): path: Union[str, IO[bytes]], format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, + bit_depth: int | None = None, ): """ Abstract method to save the video input to a file. + + bit_depth selects the encoded bit depth; None keeps the video's native depth. 
""" pass @@ -83,6 +86,14 @@ class VideoInput(ABC): components = self.get_components() return components.images.shape[2], components.images.shape[1] + def get_bit_depth(self) -> int: + """ + Returns the bit depth of the video (e.g. 8 or 10). + + Default implementation returns 8; subclasses report their real depth. + """ + return 8 + def get_duration(self) -> float: """ Returns the duration of the video in seconds. diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 4a12ff9c1..92a1298c0 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -52,6 +52,12 @@ def get_open_write_kwargs( return open_kwargs +def video_stream_bit_depth(stream) -> int: + if stream is None or stream.format is None or not stream.format.components: + return 8 + return max(component.bits for component in stream.format.components) + + class VideoFromFile(VideoInput): """ Class representing video input from a file. @@ -97,6 +103,13 @@ class VideoFromFile(VideoInput): return stream.width, stream.height raise ValueError(f"No video stream found in file '{self.__file}'") + def get_bit_depth(self) -> int: + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) # Reset the BytesIO object to the beginning + with av.open(self.__file, mode="r") as container: + video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None + return video_stream_bit_depth(video_stream) + def get_duration(self) -> float: """ Returns the duration of the video in seconds. 
@@ -257,6 +270,7 @@ class VideoFromFile(VideoInput): image_format = 'gbrpf32le' process_image_format = lambda a: a + align_graph = None audio = None streams = [video_stream] @@ -310,7 +324,24 @@ class VideoFromFile(VideoInput): checked_alpha = True - img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) + # Fix non-deterministic video decode when the video width is not a multiple of 32 + # For non-yuvj pixel formats (all H.264/H.265 video) + if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0: + if align_graph is None: + pad_w = ((frame.width + 31) // 32) * 32 + g = av.filter.Graph() + g_src = g.add_buffer(width=frame.width, height=frame.height, + format=frame.format.name, time_base=video_stream.time_base) + g_pad = g.add('pad', f'{pad_w}:{frame.height}:0:0') + g_sink = g.add('buffersink') + g_src.link_to(g_pad) + g_pad.link_to(g_sink) + g.configure() + align_graph = (g, g_src, g_sink) + align_graph[1].push(frame) + img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:, :frame.width]) + else: + img = frame.to_ndarray(format=image_format) if frame.rotation != 0: k = int(round(frame.rotation // 90)) img = np.rot90(img, k=k, axes=(0, 1)).copy() @@ -377,25 +408,32 @@ class VideoFromFile(VideoInput): format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: int | None = None, ): if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning with av.open(self.__file, mode='r') as container: container_format = container.format.name - video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None + video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None + video_encoding = video_stream.codec.name if video_stream is not None else None + source_bit_depth = video_stream_bit_depth(video_stream) reuse_streams = True if format != 
VideoContainer.AUTO and format not in container_format.split(","): reuse_streams = False if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: reuse_streams = False + if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth: + reuse_streams = False if self.__start_time or self.__duration: reuse_streams = False if not reuse_streams: + if bit_depth is None: + bit_depth = source_bit_depth components = self.get_components_internal(container) video = VideoFromComponents(components) return video.save_to( - path, format=format, codec=codec, metadata=metadata + path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth, ) streams = container.streams @@ -451,8 +489,10 @@ class VideoFromComponents(VideoInput): Class representing video input from tensors. """ - def __init__(self, components: VideoComponents): + def __init__(self, components: VideoComponents, bit_depth: int = 8): self.__components = components + # Tensor components have no inherent bit depth; this is the depth used when encoding. + self.__bit_depth = bit_depth def get_components(self) -> VideoComponents: return VideoComponents( @@ -461,18 +501,26 @@ class VideoFromComponents(VideoInput): frame_rate=self.__components.frame_rate, ) + def get_bit_depth(self) -> int: + return self.__bit_depth + def save_to( self, path: str, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, + bit_depth: int | None = None, ): """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") + # None means "use the depth this video was created with" (CreateVideo's choice). 
+ if bit_depth is None: + bit_depth = self.__bit_depth + is_10bit = bit_depth >= 10 extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value @@ -488,10 +536,11 @@ class VideoFromComponents(VideoInput): frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000) # Create a video stream + pix_fmt = "yuv420p10le" if is_10bit else "yuv420p" video_stream = output.add_stream('h264', rate=frame_rate) video_stream.width = self.__components.images.shape[2] video_stream.height = self.__components.images.shape[1] - video_stream.pix_fmt = 'yuv420p' + video_stream.pix_fmt = pix_fmt # Create an audio stream audio_sample_rate = 1 @@ -505,9 +554,14 @@ class VideoFromComponents(VideoInput): # Encode video for i, frame in enumerate(self.__components.images): - img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) - frame = av.VideoFrame.from_ndarray(img, format='rgb24') - frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264 + if is_10bit: + # 16-bit RGB keeps float precision through the conversion to 10-bit YUV. 
+ img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format="rgb48le") + else: + img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3) + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + frame = frame.reformat(format=pix_fmt) packet = video_stream.encode(frame) output.mux(packet) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index a3aa508ce..012fae3ac 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO): Type = File3D +@comfytype(io_type="FILE_3D_SPLAT_ANY") +class File3DSplatAny(ComfyTypeIO): + """General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat).""" + Type = File3D + + +@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY") +class File3DPointCloudAny(ComfyTypeIO): + """General point cloud file type - accepts any supported point cloud container (currently .ply).""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -1388,7 +1400,8 @@ class V3Data(TypedDict): class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, extra_pnginfo: Any, dynprompt: Any, - auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs): + auth_token_comfy_org: str, api_key_comfy_org: str, + comfy_usage_source: str = None, **kwargs): self.unique_id = unique_id """UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. 
It is commonly used in client-server communications (see messages).""" self.prompt = prompt @@ -1401,6 +1414,8 @@ class HiddenHolder: """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" self.api_key_comfy_org = api_key_comfy_org """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + self.comfy_usage_source = comfy_usage_source + """COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header.""" def __getattr__(self, key: str): '''If hidden variable not found, return None.''' @@ -1417,6 +1432,7 @@ class HiddenHolder: dynprompt=d.get(Hidden.dynprompt, None), auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None), api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), + comfy_usage_source=d.get(Hidden.comfy_usage_source, None), ) @classmethod @@ -1439,6 +1455,8 @@ class Hidden(str, Enum): """AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend.""" api_key_comfy_org = "API_KEY_COMFY_ORG" """API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend.""" + comfy_usage_source = "COMFY_USAGE_SOURCE" + """COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. 
comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header.""" @dataclass @@ -1642,6 +1660,8 @@ class Schema: self.hidden.append(Hidden.auth_token_comfy_org) if Hidden.api_key_comfy_org not in self.hidden: self.hidden.append(Hidden.api_key_comfy_org) + if Hidden.comfy_usage_source not in self.hidden: + self.hidden.append(Hidden.comfy_usage_source) # if is an output_node, will need prompt and extra_pnginfo if self.is_output_node: if Hidden.prompt not in self.hidden: @@ -2336,6 +2356,8 @@ __all__ = [ "File3DSPLAT", "File3DSPZ", "File3DKSPLAT", + "File3DSplatAny", + "File3DPointCloudAny", "Hooks", "HookKeyframes", "TimestepsRange", diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index 6592f6b1d..b48713d41 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -285,7 +285,7 @@ class AudioSaveHelper: results = [] for batch_number, waveform in enumerate(audio["waveform"].cpu()): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" + file = f"{filename_with_batch_num}_{counter:05}.{format}" output_path = os.path.join(full_output_folder, file) # Use original sample rate initially diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index 9c4cfb9b6..9a7049ea2 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1310,13 +1310,6 @@ class KlingTaskStatus(str, Enum): failed = 'failed' -class KlingTextToVideoModelName(str, Enum): - kling_v1 = 'kling-v1' - kling_v1_6 = 'kling-v1-6' - kling_v2_1_master = 'kling-v2-1-master' - kling_v2_5_turbo = 'kling-v2-5-turbo' - - class KlingVideoGenAspectRatio(str, Enum): field_16_9 = '16:9' field_9_16 = '9:16' @@ -5179,7 +5172,7 @@ class KlingText2VideoRequest(BaseModel): duration: Optional[KlingVideoGenDuration] = '5' external_task_id: Optional[str] = Field(None, description='Customized Task ID') mode: 
Optional[KlingVideoGenMode] = 'std' - model_name: Optional[KlingTextToVideoModelName] = 'kling-v1' + model_name: Optional[str] = 'kling-v1' negative_prompt: Optional[str] = Field( None, description='Negative text prompt', max_length=2500 ) diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index 2ad651122..4c950da84 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel): "white (255) marks areas to remove, black (0) marks areas to preserve.", ) dilate_pixels: int = Field(10) + seed: int | None = Field(None) output_format: str = Field("png") diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py index e08a519a8..7a98428c3 100644 --- a/comfy_api_nodes/apis/bria.py +++ b/comfy_api_nodes/apis/bria.py @@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel): class BriaRemoveVideoBackgroundResponse(BaseModel): status: str = Field(...) result: BriaRemoveVideoBackgroundResult | None = Field(None) + + +class BriaVideoGreenScreenRequest(BaseModel): + video: str = Field(..., description="Publicly accessible URL of the input video.") + green_shade: str = Field( + default="broadcast_green", + description="Solid chroma-key shade applied behind the foreground " + "(broadcast_green, chroma_green, or blue_screen).", + ) + output_container_and_codec: str = Field(...) + preserve_audio: bool = Field(True) + seed: int = Field(...) + + +class BriaVideoReplaceBackgroundRequest(BaseModel): + video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.") + background_url: str = Field( + ..., + description="Publicly accessible URL of the background image or video to composite behind " + "the foreground. Stretched to the foreground frame; match its aspect ratio for " + "undistorted results.", + ) + output_container_and_codec: str = Field(...) + preserve_audio: bool = Field(True) + seed: int = Field(...) 
diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 22879fe18..caaba8f36 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel): startOffset: GeminiOffset | None = Field(None) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) + + class GeminiGenerationConfig(BaseModel): - maxOutputTokens: int | None = Field(None, ge=16, le=8192) + maxOutputTokens: int | None = Field(None, ge=16, le=65536) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) temperature: float | None = Field(None, ge=0.0, le=2.0) topK: int | None = Field(None, ge=1) topP: float | None = Field(None, ge=0.0, le=1.0) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageOutputOptions(BaseModel): @@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) -class GeminiThinkingConfig(BaseModel): - includeThoughts: bool | None = Field(None) - thinkingLevel: str = Field(...) 
- - class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py index df6f2b845..6878aa6f0 100644 --- a/comfy_api_nodes/apis/runway.py +++ b/comfy_api_nodes/apis/runway.py @@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') -class RunwayTaskStatusEnum(str, Enum): - SUCCEEDED = 'SUCCEEDED' - RUNNING = 'RUNNING' - FAILED = 'FAILED' - PENDING = 'PENDING' - CANCELLED = 'CANCELLED' - THROTTLED = 'THROTTLED' - - class RunwayTaskStatusResponse(BaseModel): createdAt: datetime = Field(..., description='Task creation timestamp') id: str = Field(..., description='Task ID') @@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel): ge=0.0, le=1.0, ) - status: RunwayTaskStatusEnum + status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED") class Model4(str, Enum): @@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel): class RunwayTextToImageResponse(BaseModel): id: Optional[str] = Field(None, description='Task ID') + + +class RunwayAleph2IO: + """Custom socket types for chaining Aleph2 guidance images.""" + + KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME" + PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE" + + +# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to +# choose the request model below. The values match the Aleph2 keyframe union field names. +KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video +KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration + +# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`. 
+PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video +PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration + + +class RunwayAleph2KeyframeItem: + """A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``).""" + + def __init__(self, image, mode: str, value: float): + self.image = image + self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT + self.value = value + + +class RunwayAleph2KeyframeChain: + """An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes.""" + + def __init__(self): + self.items: list[RunwayAleph2KeyframeItem] = [] + + def add(self, item: RunwayAleph2KeyframeItem) -> None: + self.items.append(item) + + def clone(self) -> "RunwayAleph2KeyframeChain": + c = RunwayAleph2KeyframeChain() + c.items = list(self.items) + return c + + +class RunwayAleph2PromptImageItem: + """A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``).""" + + def __init__(self, image, mode: str, value: float): + self.image = image + self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION + self.value = value + + +class RunwayAleph2PromptImageChain: + """An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes.""" + + def __init__(self): + self.items: list[RunwayAleph2PromptImageItem] = [] + + def add(self, item: RunwayAleph2PromptImageItem) -> None: + self.items.append(item) + + def clone(self) -> "RunwayAleph2PromptImageChain": + c = RunwayAleph2PromptImageChain() + c.items = list(self.items) + return c + + +class RunwayAleph2KeyframeSeconds(BaseModel): + seconds: float = Field( + ..., + description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.", + ge=0.0, + ) + uri: str = Field(...) 
+ + +class RunwayAleph2KeyframeAt(BaseModel): + at: float = Field( + ..., + description="Position as a fraction [0.0, 1.0] of the input video duration.", + ge=0.0, + le=1.0, + ) + uri: str = Field(...) + + +class RunwayAleph2TimestampPosition(BaseModel): + type: str = Field(default="timestamp") + timestampSeconds: float = Field( + ..., + description="Absolute timestamp in seconds from the start of the output video.", + ge=0.0, + ) + + +class RunwayAleph2RelativePosition(BaseModel): + type: str = Field(default="position") + positionPercentage: float = Field( + ..., + description="Position as a fraction [0.0, 1.0] of the total output video duration.", + ge=0.0, + le=1.0, + ) + + +class RunwayAleph2PromptImage(BaseModel): + position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition + uri: str = Field(...) + + +class RunwayAleph2ContentModeration(BaseModel): + publicFigureThreshold: str = Field( + ..., + description='When set to "low", the content moderation system is less strict about ' + 'recognizable public figures. One of "auto" or "low".', + ) + + +class RunwayAleph2Request(BaseModel): + model: str = Field(default="aleph2") + promptText: str = Field( + ..., + description="A non-empty string describing what should appear in the output.", + min_length=1, + max_length=1000, + ) + videoUri: str = Field(...) + seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295) + contentModeration: RunwayAleph2ContentModeration = Field(...) + keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field( + None, + description="Timed guidance images placed at specific points in the input video. 
Up to 5.", + ) + promptImage: list[RunwayAleph2PromptImage] | None = Field( + None, + description="Up to 5 image keyframes for guiding the edit at specific points in the output video.", + ) + + +class RunwayAleph2Response(BaseModel): + id: str | None = Field(None, description="Task ID") diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 79961ff9d..259c54ef9 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode): max=25, tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.", ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), ], outputs=[IO.Image.Output()], hidden=[ @@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode): image: Input.Image, mask: Input.Image, dilate_pixels: int = 10, + seed: int = 0, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=256, min_height=256) mask = resize_mask_to_image(mask, image) @@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode): image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, dilate_pixels=dilate_pixels, + seed=seed, ), ) diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index ce2c9e9be..090154afb 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -12,6 +12,8 @@ from comfy_api_nodes.apis.bria import ( BriaRemoveVideoBackgroundRequest, BriaRemoveVideoBackgroundResponse, BriaStatusResponse, + BriaVideoGreenScreenRequest, + BriaVideoReplaceBackgroundRequest, InputModerationSettings, ) from comfy_api_nodes.util import ( @@ -287,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + 
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -319,6 +321,161 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) +class BriaVideoGreenScreen(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaVideoGreenScreen", + display_name="Bria Video Green Screen", + category="partner/video/Bria", + description="Replace a video's background with a solid chroma-key screen using Bria.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input( + "green_shade", + options=["broadcast_green", "chroma_green", "blue_screen"], + tooltip="Solid chroma-key shade applied behind the foreground: " + "broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + green_shade: str, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"), + data=BriaVideoGreenScreenRequest( + video=await upload_video_to_comfyapi(cls, video), + green_shade=green_shade, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + 
response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + +class BriaVideoReplaceBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaVideoReplaceBackground", + display_name="Bria Video Replace Background", + category="partner/video/Bria", + description="Replace a video's background with a supplied image or video using Bria. " + "The output keeps the foreground's resolution and frame rate; a background with a " + "different aspect ratio is stretched to fit, so match it for undistorted results.", + inputs=[ + IO.Video.Input("video", tooltip="Foreground video whose background is replaced."), + IO.Image.Input( + "background_image", + optional=True, + tooltip="Background image to composite behind the foreground. " + "Provide either a background image or a background video, not both.", + ), + IO.Video.Input( + "background_video", + optional=True, + tooltip="Background video to composite behind the foreground. 
" + "Provide either a background image or a background video, not both.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + seed: int, + background_image: Input.Image | None = None, + background_video: Input.Video | None = None, + ) -> IO.NodeOutput: + if (background_image is None) == (background_video is None): + raise ValueError("Provide either a background image or a background video, not both.") + validate_video_duration(video, max_duration=60.0) + if background_video is not None: + validate_video_duration(background_video, max_duration=60.0) + background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background") + else: + # Bria's replace_background 500s on RGBA, so drop the alpha channel before upload. 
+ background_url = await upload_image_to_comfyapi( + cls, background_image[:, :, :, :3], wait_label="Uploading background" + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"), + data=BriaVideoReplaceBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_url=background_url, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]: """Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask. @@ -376,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""", ), ) @@ -416,6 +573,8 @@ class BriaExtension(ComfyExtension): BriaImageEditNode, BriaRemoveImageBackground, BriaRemoveVideoBackground, + BriaVideoGreenScreen, + BriaVideoReplaceBackground, BriaTransparentVideoBackground, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index d8885a7e5..c30ddc446 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -7,6 +7,7 @@ from io import BytesIO import torch from typing_extensions import override +from comfy.utils import common_upscale from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, @@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image: return image +# Supported output aspect 
ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump. +SEEDANCE2_RATIO_WH = { + "16:9": (16, 9), + "4:3": (4, 3), + "1:1": (1, 1), + "3:4": (3, 4), + "9:16": (9, 16), + "21:9": (21, 9), +} +SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080} + + +def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]: + """Exact supported output (width, height) for (resolution, ratio). + + The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio + "adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped + to the nearest supported ratio, so the output keeps the frame's orientation. + """ + short = SEEDANCE2_RES_SHORT_SIDE[resolution] + if ratio not in SEEDANCE2_RATIO_WH: + aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C) + ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect)) + rw, rh = SEEDANCE2_RATIO_WH[ratio] + if rw >= rh: # landscape or square: shorter side is the height + out_w, out_h = round(short * rw / rh), short + else: # portrait: shorter side is the width + out_w, out_h = short, round(short * rh / rw) + return out_w - out_w % 2, out_h - out_h % 2 + + +def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor: + """Center-crop to the target aspect and resize to exactly width x height (lanczos).""" + samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W) + resized = common_upscale(samples, width, height, "lanczos", "center") + return resized.movedim(1, -1) + + async def _resolve_reference_assets( cls: type[IO.ComfyNode], asset_ids: list[str], @@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): if last_frame is not None and last_frame_asset_id: raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") - if first_frame is not None: - first_frame = 
_prepare_seedance_image(first_frame) - if last_frame is not None: - last_frame = _prepare_seedance_image(last_frame) + request_ratio = model["ratio"] + if first_frame_asset_id or last_frame_asset_id: + if first_frame is not None: + first_frame = _prepare_seedance_image(first_frame) + if last_frame is not None: + last_frame = _prepare_seedance_image(last_frame) + else: + # The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive") + # only applies to local image inputs we can resize. + request_ratio = "adaptive" + target_dims: tuple[int, int] | None = None + if first_frame is not None: + validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_dimensions(first_frame, min_width=300, min_height=300) + target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame) + first_frame = _resize_to_exact(first_frame, *target_dims) + if last_frame is not None: + validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_dimensions(last_frame, min_width=300, min_height=300) + if target_dims is None: + target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame) + last_frame = _resize_to_exact(last_frame, *target_dims) asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] image_assets: dict[str, str] = {} @@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): content=content, generate_audio=model["generate_audio"], resolution=model["resolution"], - ratio=model["ratio"], + ratio=request_ratio, duration=model["duration"], seed=seed, watermark=watermark, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index e75ef3835..3d4be6065 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -8,7 +8,7 @@ import os from enum import Enum from fnmatch import fnmatch from io import BytesIO -from typing import Literal 
+from typing import Any, Literal import torch from typing_extensions import override @@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, + GeminiGenerationConfig, GeminiGenerateContentResponse, GeminiImageConfig, GeminiImageGenerateContentRequest, @@ -40,13 +41,18 @@ from comfy_api_nodes.util import ( get_number_of_images, sync_op, tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, video_to_base64_string, ) GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_URL_INPUT_BUDGET = 10 +GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024 GEMINI_IMAGE_SYS_PROMPT = ( "You are an expert image-generation engine. You must ALWAYS produce an image.\n" "Interpret all user input—regardless of " @@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N return final_price / 1_000_000.0 +def create_video_parts(video_input: Input.Video) -> list[GeminiPart]: + """Convert a single video input to Gemini API compatible parts (inline MP4/H.264).""" + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + return [ + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.video_mp4, + data=base_64_string, + ) + ) + ] + + +def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]: + """Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item).""" + audio_parts: list[GeminiPart] = [] + for batch_index in range(audio_input["waveform"].shape[0]): + # Recreate an IO.AUDIO object for the given batch dimension index + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) + # Convert to MP3 format for compatibility with 
Gemini API + audio_bytes = audio_to_base64_string( + audio_at_index, + container_format="mp3", + codec_name="libmp3lame", + ) + audio_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.audio_mp3, + data=audio_bytes, + ) + ) + ) + return audio_parts + + +def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]: + """Expand any batched image tensors into individual (H, W, C) frames, preserving order.""" + frames: list[torch.Tensor] = [] + for img in images: + if len(img.shape) == 4: + frames.extend(img[i] for i in range(img.shape[0])) + else: + frames.append(img) + return frames + + +def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]: + """Expand any batched audio inputs into individual single-clip audio inputs, preserving order.""" + clips: list[Input.Audio] = [] + for audio in audios: + waveform = audio["waveform"] + for i in range(waveform.shape[0]): + clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"])) + return clips + + +async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart: + """Upload a single media unit to ComfyAPI storage and return a fileData (URL) part.""" + if kind == "image": + url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url)) + if kind == "audio": + url = await upload_audio_to_comfyapi( + cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3" + ) + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url)) + url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url)) + + +def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]: + """Encode a single media unit as an inline base64 
part; returns (part, base64_length).""" + if kind == "image": + data = tensor_to_base64_string(payload, mime_type="image/webp") + mime = GeminiMimeType.image_webp + elif kind == "audio": + data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame") + mime = GeminiMimeType.audio_mp3 + else: + data = video_to_base64_string( + payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + mime = GeminiMimeType.video_mp4 + return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data) + + +async def build_gemini_media_parts( + cls: type[IO.ComfyNode], + images: list[Input.Image], + audios: list[Input.Audio], + videos: list[Input.Video], + *, + url_budget: int = GEMINI_URL_INPUT_BUDGET, + max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES, +) -> list[GeminiPart]: + """Build Gemini parts for multimodal inputs (images, audio, video). + + fileData URLs are preferred for every media type: the upload is fetched directly by the + model, keeping the request body tiny regardless of media size. The URL budget is shared + across all media and assigned largest-first (video, then audio, then images), so that if it + is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline + payload is capped by `max_inline_bytes`. + """ + units: list[tuple[str, Any]] = ( + [("video", v) for v in videos] + + [("audio", a) for a in _flatten_audio(audios)] + + [("image", f) for f in _flatten_images(images)] + ) + + parts: list[GeminiPart] = [] + url_used = 0 + inline_bytes = 0 + for kind, payload in units: + if url_used < url_budget: + parts.append(await _media_url_part(cls, kind, payload)) + url_used += 1 + continue + part, nbytes = _media_inline_part(kind, payload) + inline_bytes += nbytes + if inline_bytes > max_inline_bytes: + raise ValueError( + f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first " + f"{url_budget} inputs are uploaded as URLs). 
Reduce the number or size of attached media." + ) + parts.append(part) + return parts + + class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode): ) """, ), + is_deprecated=True, ) - @classmethod - def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: - """Convert video input to Gemini API compatible parts.""" - - base_64_string = video_to_base64_string( - video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 - ) - return [ - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.video_mp4, - data=base_64_string, - ) - ) - ] - - @classmethod - def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: - """ - Convert audio input to Gemini API compatible parts. - - Args: - audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate. - - Returns: - List of GeminiPart objects containing the encoded audio. 
- """ - audio_parts: list[GeminiPart] = [] - for batch_index in range(audio_input["waveform"].shape[0]): - # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = Input.Audio( - waveform=audio_input["waveform"][batch_index].unsqueeze(0), - sample_rate=audio_input["sample_rate"], - ) - # Convert to MP3 format for compatibility with Gemini API - audio_bytes = audio_to_base64_string( - audio_at_index, - container_format="mp3", - codec_name="libmp3lame", - ) - audio_parts.append( - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.audio_mp3, - data=audio_bytes, - ) - ) - ) - return audio_parts - @classmethod async def execute( cls, @@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode): if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: - parts.extend(cls.create_audio_parts(audio)) + parts.extend(create_audio_parts(audio)) if video is not None: - parts.extend(cls.create_video_parts(video)) + parts.extend(create_video_parts(video)) if files is not None: parts.extend(files) @@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode): return IO.NodeOutput(output_text or "Empty response from Gemini model...") +GEMINI_V2_MODELS: dict[str, str] = { + "Gemini 3.1 Pro": "gemini-3.1-pro-preview", + "Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview", +} + + +def _gemini_text_model_inputs(thinking_default: str) -> list[Input]: + """Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls).""" + return [ + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 17)], + min=0, + ), + tooltip="Optional image(s) to use as context for the model. 
Up to 16 images.", + ), + IO.Autogrow.Input( + "audio", + template=IO.Autogrow.TemplateNames( + IO.Audio.Input("audio"), + names=["audio_1"], + min=0, + ), + tooltip="Optional audio clip to use as context for the model.", + ), + IO.Autogrow.Input( + "video", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=["video_1"], + min=0, + ), + tooltip="Optional video clip to use as context for the model.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Input Files node.", + ), + IO.Combo.Input( + "thinking_level", + options=["LOW", "HIGH"], + default=thinking_default, + tooltip="How hard the model reasons internally before answering. " + "HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.", + advanced=True, + ), + IO.Int.Input( + "max_output_tokens", + default=32768, + min=16, + max=65536, + tooltip="Maximum tokens to generate, including the model's internal thinking. " + "With thinking_level HIGH, a low value can leave no room for the answer; raise this if " + "responses come back empty or truncated. 
The model stops early when finished, so a higher " + "cap costs nothing extra for short replies.", + advanced=True, + ), + ] + + +class GeminiNodeV2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNodeV2", + display_name="Google Gemini", + category="partner/text/Gemini", + essentials_category="Text Generation", + description="Generate text responses with Google's Gemini models. Provide a text prompt and, " + "optionally, one or more images, audio clips, videos, or files as multimodal context.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text input to the model. Include detailed instructions, questions, or context.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")), + IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")), + ], + tooltip="The Gemini model used to generate the response.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + advanced=True, + tooltip="Foundational instructions that dictate the model's behavior.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "lite") ? 
{ + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : { + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + model_id = GEMINI_V2_MODELS[model["model"]] + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + images = [t for t in (model.get("images") or {}).values() if t is not None] + audios = [a for a in (model.get("audio") or {}).values() if a is not None] + videos = [v for v in (model.get("video") or {}).values() if v is not None] + if images or audios or videos: + parts.extend(await build_gemini_media_parts(cls, images, audios, videos)) + files = model.get("files") + if files is not None: + parts.extend(files) + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"), + data=GeminiGenerateContentRequest( + contents=[ + GeminiContent( + role=GeminiRole.user, + parts=parts, + ) + ], + generationConfig=GeminiGenerationConfig( + temperature=model["temperature"], + topP=model["top_p"], + maxOutputTokens=model["max_output_tokens"], + seed=seed if seed > 0 else None, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + + output_text = get_text_from_response(response) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") + + class GeminiInputFiles(IO.ComfyNode): """ Loads 
and formats input files for use with the Gemini API. @@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode): tooltip="Foundational instructions that dictate an AI's behavior.", advanced=True, ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + optional=True, + tooltip="Controls randomness in generation. Lower is more focused/deterministic.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + optional=True, + tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.", + advanced=True, + ), ], outputs=[ IO.Image.Output(), @@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode): seed: int, response_modalities: str, system_prompt: str = "", + temperature: float = 1.0, + top_p: float = 0.95, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) model_choice = model["model"] @@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + temperature=temperature, + topP=top_p, ), systemInstruction=gemini_system_prompt, ), @@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ GeminiNode, + GeminiNodeV2, GeminiImage, GeminiImage2, GeminiNanoBanana2, diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index d11e42540..c81d3503d 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -436,7 +436,7 @@ async def execute_text2video( negative_prompt=negative_prompt if negative_prompt else None, duration=KlingVideoGenDuration(duration), mode=KlingVideoGenMode(model_mode), - model_name=KlingVideoGenModelName(model_name), + model_name=model_name, cfg_scale=cfg_scale, 
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), camera_control=camera_control, diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py index 34369f05f..b9e6268f2 100644 --- a/comfy_api_nodes/nodes_krea.py +++ b/comfy_api_nodes/nodes_krea.py @@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima _MODEL_MEDIUM = "Krea 2 Medium" +_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo" _MODEL_LARGE = "Krea 2 Large" _MODEL_ENDPOINTS: dict[str, str] = { _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium", + _MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo", _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large", } @@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F def _krea_model_inputs() -> list: - """Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo.""" + """Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo.""" return [ IO.Combo.Input( "aspect_ratio", @@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode): "model", options=[ IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()), IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()), ], tooltip="Krea 2 Medium is best for expressive illustrations; " @@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode): ), expr=""" ( - $isLarge := widgets.model = "krea 2 large"; + $rates := { + "krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02}, + "krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04}, + "krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07} + }; + $r := $lookup($rates, widgets.model); $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0; $hasStyle := $lookup(inputs, "model.style_reference").connected; - $usd := $hasMoodboard - ? ($isLarge ? 0.07 : 0.04) - : ($hasStyle - ? 
($isLarge ? 0.065 : 0.035) - : ($isLarge ? 0.06 : 0.03)); + $usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text); {"type":"usd","usd": $usd} ) """, diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 0fe5fb9d0..ad62f2164 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -9,6 +9,7 @@ from PIL import Image from typing_extensions import override import folder_paths +from comfy.utils import common_upscale from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.openai import ( InputFileContent, @@ -62,7 +63,8 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten timeout: Request timeout in seconds. Defaults to None (no timeout). Returns: - A torch.Tensor representing the image (1, H, W, C). + A torch.Tensor of shape (N, H, W, C) with all returned images; images whose + dimensions differ from the first image's are resized to match it. Raises: ValueError: If the response is not valid. @@ -89,6 +91,14 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten arr = np.asarray(pil_img).astype(np.float32) / 255.0 image_tensors.append(torch.from_numpy(arr)) + # With size="auto" the API can return images whose dimensions differ by a few pixels within a single response + # resize them to the first image's dimensions so they can be stacked into one batch. 
+ ref_h, ref_w = image_tensors[0].shape[:2] + for i, t in enumerate(image_tensors): + if t.shape[:2] != (ref_h, ref_w): + samples = t.unsqueeze(0).movedim(-1, 1) + samples = common_upscale(samples, ref_w, ref_h, "bilinear", "center") + image_tensors[i] = samples.movedim(1, -1).squeeze(0) return torch.stack(image_tensors, dim=0) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index b9c5c81a1..013a193d9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import ( Model4, ReferenceImage, RunwayTextToImageAspectRatioEnum, + RunwayAleph2IO, + RunwayAleph2KeyframeChain, + RunwayAleph2KeyframeItem, + RunwayAleph2PromptImageChain, + RunwayAleph2PromptImageItem, + RunwayAleph2Request, + RunwayAleph2Response, + RunwayAleph2KeyframeSeconds, + RunwayAleph2KeyframeAt, + RunwayAleph2PromptImage, + RunwayAleph2TimestampPosition, + RunwayAleph2RelativePosition, + RunwayAleph2ContentModeration, + KEYFRAME_MODE_SECONDS, + KEYFRAME_MODE_AT, + PROMPT_IMAGE_MODE_TIMESTAMP, + PROMPT_IMAGE_MODE_POSITION, ) from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, validate_image_dimensions, validate_image_aspect_ratio, + validate_video_duration, upload_images_to_comfyapi, + upload_image_to_comfyapi, + upload_video_to_comfyapi, download_url_to_video_output, download_url_to_image_tensor, ApiEndpoint, @@ -45,6 +65,7 @@ from comfy_api_nodes.util import ( ) PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" +PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" PATH_GET_TASK_STATUS = "/proxy/runway/tasks" @@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256 AVERAGE_DURATION_T2I_SECONDS = 41 -class RunwayApiError(Exception): - """Base exception for Runway API errors.""" - - pass - - class RunwayGen4TurboAspectRatio(str, Enum): """Aspect ratios supported for Image to Video API when using gen4_turbo model.""" 
@@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: return None -def extract_progress_from_task_status( - response: TaskStatusResponse, -) -> float | None: - if hasattr(response, "progress") and response.progress is not None: - return response.progress * 100 - return None - - def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: @@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: async def get_response( cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: - """Poll the task status until it is finished then get the response.""" return await poll_op( cls, ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), response_model=TaskStatusResponse, - status_extractor=lambda r: r.status.value, + status_extractor=lambda r: r.status, estimated_duration=estimated_duration, - progress_extractor=extract_progress_from_task_status, + progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None, ) @@ -127,7 +133,7 @@ async def generate_video( final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: - raise RunwayApiError("Runway task succeeded but no video data found in response.") + raise ValueError("Runway task succeeded but no video data found in response.") video_url = get_video_url_from_task_status(final_response) return await download_url_to_video_output(video_url) @@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): mime_type="image/png", ) if len(download_urls) != 2: - raise RunwayApiError("Failed to upload one or more images to comfy api.") + raise ValueError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( @@ -514,11 +520,321 @@ class 
RunwayTextToImageNode(IO.ComfyNode): estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: - raise RunwayApiError("Runway task succeeded but no image data found in response.") + raise ValueError("Runway task succeeded but no image data found in response.") return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response))) +_TIMING_ABSOLUTE = "Absolute time (seconds)" +_TIMING_FRACTION = "Fraction of duration (0.0-1.0)" + + +class RunwayAleph2KeyframeNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2KeyframeNode", + display_name="Runway Aleph2 Keyframe", + category="partner/video/Runway", + description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 " + "steers the edit at that point of your footage. Connect this to the 'keyframes' input of " + "the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional " + "'keyframes' input below.", + inputs=[ + IO.Image.Input( + "image", + tooltip="The guidance image to apply at the chosen moment of the input video.", + ), + IO.DynamicCombo.Input( + "timing", + options=[ + IO.DynamicCombo.Option( + _TIMING_ABSOLUTE, + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=30.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from start of the input video where this image applies.", + ), + ], + ), + IO.DynamicCombo.Option( + _TIMING_FRACTION, + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the input video this image applies, " + "as a fraction of its duration (0.0 = start, 1.0 = end).", + ), + ], + ), + ], + tooltip="How to place this image on the input video's timeline.", + ), + IO.Custom(RunwayAleph2IO.KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Optional earlier keyframes to chain with this one.", + 
), + ], + outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")], + ) + + @classmethod + def execute( + cls, + image: Input.Image, + timing: dict, + keyframes: RunwayAleph2KeyframeChain | None = None, + ) -> IO.NodeOutput: + chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain() + if timing["timing"] == _TIMING_ABSOLUTE: + mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"]) + else: + mode, value = KEYFRAME_MODE_AT, float(timing["fraction"]) + chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class RunwayAleph2PromptImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2PromptImageNode", + display_name="Runway Aleph2 Prompt Image", + category="partner/video/Runway", + description="Anchor a guidance image to a moment of the output (result) video, to guide what " + "the edited video looks like at that point. Connect this to the 'prompt_images' input of the " + "Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional " + "'prompt_images' input below.", + inputs=[ + IO.Image.Input( + "image", + tooltip="The guidance image to place at the chosen moment of the output video.", + ), + IO.DynamicCombo.Input( + "position", + options=[ + IO.DynamicCombo.Option( + _TIMING_ABSOLUTE, + [ + IO.Float.Input( + "seconds", + default=0.0, + min=0.0, + max=30.0, + step=0.1, + display_mode=IO.NumberDisplay.number, + tooltip="Time in seconds from start of the output video where this image applies.", + ), + ], + ), + IO.DynamicCombo.Option( + _TIMING_FRACTION, + [ + IO.Float.Input( + "fraction", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Where in the output video this image applies, " + "as a fraction of its duration (0.0 = start, 1.0 = end).", + ), + ], + ), + ], + tooltip="How to place this image on the output video's timeline.", 
+ ), + IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input( + "prompt_images", + optional=True, + tooltip="Optional earlier prompt images to chain with this one.", + ), + ], + outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")], + ) + + @classmethod + def execute( + cls, + image: Input.Image, + position: dict, + prompt_images: RunwayAleph2PromptImageChain | None = None, + ) -> IO.NodeOutput: + chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain() + if position["position"] == _TIMING_ABSOLUTE: + mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"]) + else: + mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"]) + chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value)) + return IO.NodeOutput(chain) + + +class RunwayAleph2VideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RunwayAleph2VideoToVideoNode", + display_name="Runway Aleph2 Video to Video", + category="partner/video/Runway", + description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms " + "your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping " + "the original motion and timing; the output resolution matches the input video, which must be " + "2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to " + "the input video) or prompt images (anchored to the output video) - use one or the other, not both.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes what should appear in the output (1-1000 characters).", + ), + IO.Video.Input( + "video", + tooltip="Input video to edit. 
Must be 2-30 seconds at 30 fps or lower.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + tooltip="Random seed for generation", + ), + IO.Combo.Input( + "public_figure_threshold", + options=["auto", "low"], + default="low", + tooltip="Content moderation for recognizable public figures.", + ), + IO.Custom(RunwayAleph2IO.KEYFRAME).Input( + "keyframes", + optional=True, + tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). " + "Use keyframes or prompt images, not both.", + ), + IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input( + "prompt_images", + optional=True, + tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). " + "Use keyframes or prompt images, not both.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + video: Input.Video, + seed: int, + public_figure_threshold: str = "low", + keyframes: RunwayAleph2KeyframeChain | None = None, + prompt_images: RunwayAleph2PromptImageChain | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=1000) + validate_video_duration( + video, + min_duration=2.0, + max_duration=30.0, + ) + try: + fps = float(video.get_frame_rate()) + except Exception: + fps = None + if fps is not None and fps > 30.0 + 0.01: + raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.") + + if (keyframes and keyframes.items) and (prompt_images and prompt_images.items): + raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.") + + video_duration: float | None = None + try: + 
video_duration = video.get_duration() + except Exception: + video_duration = None + + def _check_seconds(value: float, label: str) -> None: + if video_duration is not None and value > video_duration + 0.0001: + raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).") + + video_url = await upload_video_to_comfyapi(cls, video) + + keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = [] + if keyframes is not None: + if len(keyframes.items) > 5: + raise ValueError("Aleph2 supports at most 5 keyframes.") + for item in keyframes.items: + image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png") + if item.mode == KEYFRAME_MODE_SECONDS: + _check_seconds(item.value, "Keyframe timestamp") + keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url)) + else: + keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url)) + + prompt_image_models: list[RunwayAleph2PromptImage] = [] + if prompt_images is not None: + if len(prompt_images.items) > 5: + raise ValueError("Aleph2 supports at most 5 prompt images.") + for item in prompt_images.items: + image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png") + position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition + if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP: + _check_seconds(item.value, "Prompt image timestamp") + position = RunwayAleph2TimestampPosition(timestampSeconds=item.value) + else: + position = RunwayAleph2RelativePosition(positionPercentage=item.value) + prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url)) + + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"), + response_model=RunwayAleph2Response, + data=RunwayAleph2Request( + promptText=prompt, + videoUri=video_url, + seed=seed, + 
contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold), + keyframes=keyframe_models or None, + promptImage=prompt_image_models or None, + ), + ) + + final_response = await get_response(cls, initial_response.id) + if not final_response.output: + raise ValueError("Runway task succeeded but no video data found in response.") + + return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response))) + + class RunwayExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension): RunwayImageToVideoNodeGen3a, RunwayImageToVideoNodeGen4, RunwayTextToImageNode, + RunwayAleph2VideoToVideoNode, + RunwayAleph2KeyframeNode, + RunwayAleph2PromptImageNode, ] diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index 9ce896ed0..24a9a0b06 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -16,7 +16,7 @@ from comfy_api_nodes.util import ( ) from comfy_api_nodes.util._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, get_node_id, is_processing_interrupted, ) @@ -174,8 +174,7 @@ async def _stream_sonilo_music( """POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes.""" url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/")) - headers: dict[str, str] = {} - headers.update(get_auth_header(cls)) + headers = get_comfy_api_headers(cls) headers.update(endpoint.headers) node_id = get_node_id(cls) diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py index 648defe3d..83cf7b001 100644 --- a/comfy_api_nodes/util/_helpers.py +++ b/comfy_api_nodes/util/_helpers.py @@ -9,6 +9,7 @@ from io import BytesIO from yarl import URL from comfy.cli_args import args +from comfy.deploy_environment import get_deploy_environment from comfy.model_management import 
processing_interrupted from comfy_api.latest import IO @@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: return {} +def get_usage_source(node_cls: type[IO.ComfyNode]) -> str: + """Source of the prompt that triggered this API node. + + Defaults to "comfyui-api" when the submitting client didn't identify itself, + i.e. a direct API call to this server. + """ + return node_cls.hidden.comfy_usage_source or "comfyui-api" + + +def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + """Common headers (auth, deploy environment, usage source) for Comfy API requests. + + Centralizes the shared header set so every Comfy API request sends a consistent + set and new shared headers only need to be added in one place. Intended for + relative/cloud URLs resolved against ``default_base_url()``; because the result + includes auth, callers must not attach it to arbitrary absolute/presigned URLs. + """ + return { + **get_auth_header(node_cls), + "Comfy-Env": get_deploy_environment(), + "Comfy-Usage-Source": get_usage_source(node_cls), + } + + def default_base_url() -> str: return getattr(args, "comfy_api_base", "https://api.comfy.org") diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 57c501724..adcde7bcb 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -19,12 +19,10 @@ from comfy import utils from comfy_api.latest import IO from server import PromptServer -from comfy.deploy_environment import get_deploy_environment - from . import request_logger from ._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, get_node_id, is_processing_interrupted, sleep_with_interrupt, @@ -645,8 +643,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? 
- payload_headers.update(get_auth_header(cfg.node_cls)) - payload_headers["Comfy-Env"] = get_deploy_environment() + payload_headers.update(get_comfy_api_headers(cfg.node_cls)) if cfg.endpoint.headers: payload_headers.update(cfg.endpoint.headers) diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py index aa588d038..0ec3c6e66 100644 --- a/comfy_api_nodes/util/download_helpers.py +++ b/comfy_api_nodes/util/download_helpers.py @@ -17,7 +17,7 @@ from folder_paths import get_output_directory from . import request_logger from ._helpers import ( default_base_url, - get_auth_header, + get_comfy_api_headers, is_processing_interrupted, sleep_with_interrupt, to_aiohttp_url, @@ -64,7 +64,7 @@ async def download_url_to_bytesio( if cls is None: raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) - headers = get_auth_header(cls) + headers = get_comfy_api_headers(cls) while True: attempt += 1 diff --git a/comfy_execution/asset_enrichment.py b/comfy_execution/asset_enrichment.py new file mode 100644 index 000000000..38e9496a8 --- /dev/null +++ b/comfy_execution/asset_enrichment.py @@ -0,0 +1,66 @@ +"""Enrich executed-node output entries with asset id.""" +import logging +import os + + +def enrich_output_with_assets(output_ui: dict) -> dict: + """Register file-type output entries as assets and inject their ``id``. + + Runs at output-processing time, once per produced output, when + --enable-assets is set. Returns a new dict; entries without a resolvable + on-disk file path are left unchanged. Errors are caught per-entry so a + failure never blocks execution or the other entries. 
+ """ + from comfy.cli_args import args + if not args.enable_assets: + return output_ui + + import folder_paths + from app.assets.services.ingest import register_file_in_place, DependencyMissingError + + enriched = {} + for key, entries in output_ui.items(): + if not isinstance(entries, list): + enriched[key] = entries + continue + new_entries = [] + for entry in entries: + if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry: + new_entries.append(entry) + continue + try: + base = folder_paths.get_directory_by_type(entry["type"]) + if base is None: + new_entries.append(entry) + continue + base_abs = os.path.abspath(base) + abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"])) + try: + if os.path.commonpath([base_abs, abs_path]) != base_abs: + raise ValueError("escapes base") + except ValueError: + logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename")) + new_entries.append(entry) + continue + if not os.path.isfile(abs_path): + new_entries.append(entry) + continue + + # Register unconditionally: the file was just produced, and + # register_file_in_place re-hashes so an overwritten path can + # never carry a stale id. + result = register_file_in_place( + abs_path=abs_path, + name=entry["filename"], + tags=[entry["type"]], + ) + + entry = dict(entry) + entry["id"] = result.ref.id + except DependencyMissingError: + logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename")) + except Exception: + logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True) + new_entries.append(entry) + enriched[key] = new_entries + return enriched diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index fcd7ef735..20ebae155 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint. 
Provides normalization and helper functions for job status tracking. """ +import uuid from typing import Optional from comfy_api.internal import prune_dict @@ -19,6 +20,25 @@ class JobStatus: ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] +def validate_job_id(value) -> str: + """Validate a client-supplied job (prompt) id. + + Job ids must be UUIDs in the canonical lowercase hyphenated form. The id + is stored and compared verbatim everywhere downstream — history keys, + websocket events, and /interrupt matching — so accepting another spelling + would silently rewrite the client's id and then miss every exact-match + lookup. Rejecting loudly beats that. + + Returns the id unchanged. Raises ValueError when the value is not a + string in canonical UUID form. + """ + if not isinstance(value, str): + raise ValueError(f"job id must be a string, got {type(value).__name__}") + if str(uuid.UUID(value)) != value: + raise ValueError("job id must be a UUID in canonical lowercase hyphenated form") + return value + + # Media types that can be previewed in the frontend PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 532140be7..91ca01a23 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC)", + display_name="Save Audio (FLAC) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -166,6 +166,7 @@ class SaveAudio(IO.ComfyNode): IO.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_deprecated=True, is_output_node=True, outputs=[IO.Audio.Output("audio")] ) @@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio 
(MP3)", + display_name="Save Audio (MP3) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -195,6 +196,7 @@ class SaveAudioMP3(IO.ComfyNode): IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_deprecated=True, is_output_node=True, outputs=[IO.Audio.Output("audio")] ) @@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus)", + display_name="Save Audio (Opus) (DEPRECATED)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -225,6 +227,7 @@ class SaveAudioOpus(IO.ComfyNode): IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_deprecated=True, is_output_node=True, outputs=[IO.Audio.Output("audio")] ) @@ -241,6 +244,54 @@ class SaveAudioOpus(IO.ComfyNode): ) +class SaveAudioAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioAdvanced", + search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"], + display_name="Save Audio (Advanced)", + description="Saves the input audio to your ComfyUI output directory.", + category="audio", + inputs=[ + IO.Audio.Input("audio", tooltip="The audio to save."), + IO.String.Input( + "filename_prefix", + default="audio/ComfyUI", + tooltip=( + "The prefix for the file to save. May include formatting tokens " + "such as %date:yyyy-MM-dd%." 
+ ), + ), + IO.DynamicCombo.Input( + "format", + options=[ + IO.DynamicCombo.Option("flac", []), + IO.DynamicCombo.Option("mp3", [ + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ]), + IO.DynamicCombo.Option("opus", [ + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ]), + ], + tooltip="The file format in which to save the audio.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput: + file_format = format.get("format", None) + quality = format.get("quality", None) + if quality: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) + else: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) + return IO.NodeOutput(ui=ui) + + class PreviewAudio(IO.ComfyNode): @classmethod def define_schema(cls): @@ -823,6 +874,7 @@ class AudioExtension(ComfyExtension): SaveAudio, SaveAudioMP3, SaveAudioOpus, + SaveAudioAdvanced, LoadAudio, PreviewAudio, ConditioningStableAudio, diff --git a/comfy_extras/nodes_bernini.py b/comfy_extras/nodes_bernini.py new file mode 100644 index 000000000..227fa5753 --- /dev/null +++ b/comfy_extras/nodes_bernini.py @@ -0,0 +1,115 @@ +import torch +from typing_extensions import override + +import comfy.model_management +import comfy.utils +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def _resize_long_edge(image, max_size, stride=16): + """Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`""" + h, w = image.shape[1], image.shape[2] + scale = min(max_size / max(h, w), 1.0) + nh = max(stride, round(h * scale / stride) * stride) + nw = max(stride, round(w * scale / stride) * stride) + return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, 
nh, "area", "disabled").movedim(1, -1) + + +class BerniniConditioning(io.ComfyNode): + """Bernini in-context conditioning for a Wan2.2-A14B model. + + Attaches the VAE-encoded source video / reference images to the conditioning + source video first, then each reference image + + The task is inferred from which inputs are connected: + (nothing) -> t2v (text-to-video) + source_video -> v2v (video-to-video) + source_video + ref_images -> rv2v (reference-guided video editing) + ref_images only -> r2v (reference-to-video) + source_video + ref_video -> ads2v (insert image/video into video) + + source_video is the edit base / canvas (resized to width x height). + reference_video is moving content to composite in. + Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...). + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="BerniniConditioning", + display_name="Bernini Conditioning", + category="conditioning/video_models", + description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)." + "Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=8192, step=16), + io.Int.Input("height", default=480, min=16, max=8192, step=16), + io.Int.Input("length", default=81, min=1, max=8192, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("source_video", optional=True, tooltip=( + "Source video to edit or restyle (v2v, rv2v). 
Resized to width/height and trimmed to length.")), + io.Image.Input("reference_video", optional=True, tooltip=( + "Video to insert into the source video (ads2v).")), + io.Autogrow.Input("reference_images", optional=True, + template=io.Autogrow.TemplatePrefix( + input=io.Image.Input("reference_image", tooltip=( + "Reference image injected as an in-context token (r2v, rv2v).")), + prefix="reference_image_", min=0, max=8)), + io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=( + "Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, + source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], + device=comfy.model_management.intermediate_device()) + + # source_video (1), reference_video (2), reference_images (3, 4, ...). + context = [] + if source_video is not None: + vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1) + context.append(vae.encode(vid[:, :, :, :3])) + + if reference_video is not None: + ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect + context.append(vae.encode(ref_vid[:, :, :, :3])) + + # reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a + # separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame). 
+ if reference_images: + for name in sorted(reference_images): + imgs = reference_images[name] + if imgs is None: + continue + for i in range(imgs.shape[0]): + img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref + context.append(vae.encode(img[:, :, :, :3])) + + if context: + positive = node_helpers.conditioning_set_values(positive, {"context_latents": context}) + negative = node_helpers.conditioning_set_values(negative, {"context_latents": context}) + + return io.NodeOutput(positive, negative, {"samples": latent}) + + +class BerniniExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + BerniniConditioning, + ] + + +async def comfy_entrypoint() -> BerniniExtension: + return BerniniExtension() diff --git a/comfy_extras/nodes_bg_removal.py b/comfy_extras/nodes_bg_removal.py index 9dc9ad854..c7b33a821 100644 --- a/comfy_extras/nodes_bg_removal.py +++ b/comfy_extras/nodes_bg_removal.py @@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode): category="image/background removal", description="Generates a foreground mask to remove the background from an image using a background removal model.", inputs=[ - IO.Image.Input("image", tooltip="Input image to remove the background from"), - IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask") + IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"), + IO.Image.Input("image", tooltip="Input image to remove the background from") ], outputs=[ IO.Mask.Output("mask", tooltip="Generated foreground mask") ] ) @classmethod - def execute(cls, image, bg_removal_model): + def execute(cls, bg_removal_model, image): mask = bg_removal_model.encode_image(image) return IO.NodeOutput(mask) diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py index 01a05035e..688254e4e 100644 --- a/comfy_extras/nodes_color.py +++ 
b/comfy_extras/nodes_color.py @@ -7,29 +7,29 @@ class ColorToRGBInt(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="ColorToRGBInt", - display_name="Color to RGB Int", + display_name="Color Picker", category="utilities", - description="Convert a color to a RGB integer value.", + description="Return a color RGB integer value and hexadecimal representation.", inputs=[ io.Color.Input("color"), ], outputs=[ io.Int.Output(display_name="rgb_int"), + io.Color.Output(display_name="hex") ], ) @classmethod - def execute( - cls, - color: str, - ) -> io.NodeOutput: + def execute(cls, color: str) -> io.NodeOutput: # expect format #RRGGBB if len(color) != 7 or color[0] != "#": raise ValueError("Color must be in format #RRGGBB") r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) - return io.NodeOutput(r * 256 * 256 + g * 256 + b) + + rgb_int = r * 256 * 256 + g * 256 + b + return io.NodeOutput(rgb_int, color) class ColorExtension(ComfyExtension): diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 2f4ff1f70..3e97084a4 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -933,9 +933,10 @@ class Guider_DualModel(comfy.samplers.CFGGuider): def predict_noise(self, x, timestep, model_options={}, seed=None): positive = self.conds.get("positive", None) - if self.uncond_inner is None: # cfg == 1 or no negative -> single model, cond only - return comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0] cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0] + # uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only + if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)): + return cond uncond_model_options = model_options if "multigpu_clones" in 
model_options: # TODO: support multigpu instead of just running uncond on a single GPU @@ -1140,7 +1141,7 @@ class CFGOverride(io.ComfyNode): return io.Schema( node_id="CFGOverride", display_name="CFG Override", - description="Override cfg to a fixed value over a [start, end] percent slice of the steps. " + description="Override cfg to a fixed value over a [start, end] percent (sigma) range. " "With multiple overrides, the one nearest the sampler wins on overlap.", category="sampling/custom_sampling", inputs=[ diff --git a/comfy_extras/nodes_depth_anything_3.py b/comfy_extras/nodes_depth_anything_3.py new file mode 100644 index 000000000..020112515 --- /dev/null +++ b/comfy_extras/nodes_depth_anything_3.py @@ -0,0 +1,681 @@ +"""ComfyUI nodes for Depth Anything 3. +Model capability matrix: + +Variant head_type has_sky has_conf cam_dec +DA3-Small dualdpt False True yes +DA3-Base dualdpt False True yes +DA3-Mono-Large dpt True False no +DA3-Metric-Large dpt True False no (raw output is metres) +""" + +from __future__ import annotations + +import logging +from typing_extensions import override + +import torch + +import comfy.model_management as mm +import comfy.sd +import folder_paths +from comfy.ldm.colormap import turbo as _turbo +from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess +from comfy_api.latest import ComfyExtension, Types, io +from comfy.ldm.moge.geometry import triangulate_grid_mesh + +DA3ModelType = io.Custom("DA3_MODEL") +DA3Geometry = io.Custom("DA3_GEOMETRY") +DA3PointCloud = io.Custom("DA3_POINT_CLOUD") + +# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them): +# +# Per-frame tensors - B = batch size in mono mode; B = S (number of views) in multi-view mode. 
+# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention) +# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present) +# "mode": str -- "mono" or "multiview" (always present) +# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only) +# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only) +# +# Multi-view only - S = number of views; the leading 1 is the scene dimension from the model. +# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices +# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics +# +# DA3_POINT_CLOUD is a dict: +# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back) +# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None +# "confidence": torch.Tensor (N,) -- raw confidence per point, or None + + +def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor: + """Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space.""" + H, W = depth.shape + u = torch.arange(W, dtype=torch.float32, device=depth.device) + v = torch.arange(H, dtype=torch.float32, device=depth.device) + u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W) + pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3) + rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix) + return rays * depth.unsqueeze(-1) # (H, W, 3) + + +def _da3_default_K(H: int, W: int) -> torch.Tensor: + """Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry).""" + fx = fy = float(W) * 0.7 + return torch.tensor([[fx, 0.0, (W - 1) / 2.0], + [0.0, fy, (H - 1) / 2.0], + [0.0, 0.0, 1.0]], dtype=torch.float32) + + +def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor: + """Return pixel-space K for batch element b, falling back to a default estimate.""" + if "intrinsics" in geometry: + # shape (1, S, 3, 3) - leading 
scene dimension from the multiview head + return geometry["intrinsics"][0, b].float() + logging.getLogger("comfy").warning( + "DA3_GEOMETRY has no intrinsics (mono-mode model). " + "Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate." + ) + return _da3_default_K(H, W) + + +def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None: + """Return the world-to-camera extrinsic for batch element b, or None in mono mode. + + The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4). + _da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing. + """ + if "extrinsics" not in geometry: + return None + return geometry["extrinsics"][0, b].float() + + +def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor: + """Transform (H,W,3) OpenCV camera-space points to world space.""" + E = E.to(points_cam.device).float() + if not torch.isfinite(E).all(): + logging.getLogger("comfy").warning( + "DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). " + "Falling back to camera-space coordinates." 
+ ) + return points_cam + H, W, _ = points_cam.shape + R = E[:3, :3] # (3, 3) rotation + t = E[:3, 3] # (3,) translation + R_inv = R.T # rotation inverse = transpose for orthogonal R + t_inv = -(R_inv @ t) # (3,) + pts = points_cam.reshape(-1, 3) # (N, 3) + pts_world = pts @ R_inv.T + t_inv # (N, 3) + return pts_world.reshape(H, W, 3) + + +def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor: + """Map raw confidence to [0, 1] per image.""" + B = conf.shape[0] + out = [] + for i in range(B): + c = conf[i] + c_min, c_max = c.min(), c.max() + out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c)) + return torch.stack(out, dim=0) + + +def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor: + """Build (H,W) bool keep-mask from sky probability and confidence.""" + mask = torch.ones(H, W, dtype=torch.bool) + if use_sky_mask and "sky" in geometry: + mask = mask & (geometry["sky"][b] < 0.5) + if "confidence" in geometry and confidence_threshold > 0.0: + conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0] + mask = mask & (conf_norm >= confidence_threshold) + return mask + + +class LoadDA3Model(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadDA3Model", + display_name="Load Depth Anything 3", + category="model/loaders", + inputs=[ + io.Combo.Input( + "model_name", + options=folder_paths.get_filename_list("geometry_estimation"), + ), + io.Combo.Input( + "weight_dtype", + options=["default", "fp16", "bf16", "fp32"], + default="default", + ), + ], + outputs=[DA3ModelType.Output()], + ) + + @classmethod + def execute(cls, model_name, weight_dtype) -> io.NodeOutput: + model_options = {} + if weight_dtype == "fp16": + model_options["dtype"] = torch.float16 + elif weight_dtype == "bf16": + model_options["dtype"] = torch.bfloat16 + elif weight_dtype == "fp32": + model_options["dtype"] = torch.float32 + + path = 
folder_paths.get_full_path_or_raise("geometry_estimation", model_name) + model = comfy.sd.load_diffusion_model(path, model_options=model_options) + return io.NodeOutput(model) + + +def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"): + """Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None).""" + assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + + B, H, W, _ = image.shape + mm.load_model_gpu(model_patcher) + diffusion = model_patcher.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + depths, confs, skies = [], [], [] + for i in range(B): + single = image[i:i + 1].to(device) + x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method) + x = x.to(dtype=dtype) + with torch.no_grad(): + out = diffusion(x) + + depth_lr = out["depth"] + depth_full = torch.nn.functional.interpolate( + depth_lr.unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + depths.append(depth_full) + + if "depth_conf" in out: + conf_full = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + confs.append(conf_full) + if "sky" in out: + sky_full = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + skies.append(sky_full) + + depth = torch.cat(depths, dim=0) + confidence = torch.cat(confs, dim=0) if confs else None + sky = torch.cat(skies, dim=0) if skies else None + return depth, confidence, sky + + +class DA3Inference(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3Inference", + search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", 
"disparity"], + display_name="Run Depth Anything 3", + category="image/geometry estimation", + description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.", + inputs=[ + DA3ModelType.Input("da3_model"), + io.Image.Input("image"), + io.Int.Input("resolution", default=504, min=140, max=2520, step=14, + tooltip="Resolution the model runs at (longest side, multiple of 14).\n" + "Lower = faster / less VRAM.\n" + "Higher = more detail.\n" + "Output is upsampled back to the original size."), + io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize", + tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n" + "lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."), + io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n" + "multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).", + options=[ + io.DynamicCombo.Option("mono", []), + io.DynamicCombo.Option("multiview", [ + io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced", + tooltip="Which view acts as the geometric anchor.\n" + "- saddle_balanced: the view most 'average' across all others (best general choice).\n" + "- saddle_sim_range: the view most visually distinct from the others.\n" + "- first / middle: fixed positional picks."), + io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec", + tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n" + "- cam_dec: learned from image features.\n" + "- ray_pose: derived geometrically from the model's 3D ray output.\n" + "Affects perspective correctness of the 3D output. 
Try both if results look distorted."), + ]), + ]), + ], + outputs=[ + DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n" + "Always has the keys: depth, image, mode.\n" + "Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."), + ], + ) + + @classmethod + def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput: + mode_val = mode["mode"] # "mono" or "multiview" + + if mode_val == "mono": + return cls._execute_mono(da3_model, image, resolution, resize_method) + + # Capability checks for multi-view mode. + diffusion = da3_model.model.diffusion_model + pose_method = mode["pose_method"] + ref_view_strategy = mode["ref_view_strategy"] + + has_cam_dec = diffusion.cam_dec is not None + has_dualdpt = diffusion.head_type == "dualdpt" + + if not has_cam_dec and not has_dualdpt: + raise ValueError( + "multi-view mode requires Small or Base model. The loaded model " + f"(head_type='{diffusion.head_type}') does not support cross-view " + "attention or camera pose estimation. Switch mode to 'mono', or " + "load Small or Base model for mult-view." + ) + + if pose_method == "cam_dec" and not has_cam_dec: + raise ValueError( + "pose_method='cam_dec' requires a camera decoder, but the loaded " + f"model (head_type='{diffusion.head_type}') does not have one. " + "Use pose_method='ray_pose' instead." + ) + if pose_method == "ray_pose" and not has_dualdpt: + raise ValueError( + "pose_method='ray_pose' requires a DualDPT head, but the loaded " + f"model has a '{diffusion.head_type}' head. " + "Use pose_method='cam_dec' instead." 
+ ) + + return cls._execute_multiview( + da3_model, image, resolution, resize_method, + ref_view_strategy, pose_method, + ) + + @classmethod + def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput: + depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method) + + geometry: dict = { + "depth": depth.contiguous(), + "image": image[..., :3].cpu(), + "mode": "mono", + } + if sky is not None: + geometry["sky"] = sky.contiguous() + if confidence is not None: + geometry["confidence"] = confidence.contiguous() + return io.NodeOutput(geometry) + + @classmethod + def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput: + assert image.ndim == 4 and image.shape[-1] == 3, \ + f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}" + S, H, W, _ = image.shape + + mm.load_model_gpu(model) + diffusion = model.model.diffusion_model + device = mm.get_torch_device() + dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32 + + # All views in a single forward pass: (1, S, 3, H', W'). 
+ x = image.to(device) + x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method) + x = x.to(dtype=dtype).unsqueeze(0) + + use_ray_pose = (pose_method == "ray_pose") + with torch.no_grad(): + out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy) + + depth = torch.nn.functional.interpolate( + out["depth"].float().unsqueeze(1), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + + sky = None + if "sky" in out: + sky = torch.nn.functional.interpolate( + out["sky"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + + if "extrinsics" in out and "intrinsics" in out: + extrinsics = out["extrinsics"].float().cpu() + intrinsics = out["intrinsics"].float().cpu() + else: + extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone() + intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone() + + geometry: dict = { + "depth": depth.contiguous(), + "image": image[..., :3].cpu(), + "mode": "multiview", + "extrinsics": extrinsics.contiguous(), + "intrinsics": intrinsics.contiguous(), + } + if sky is not None: + geometry["sky"] = sky.contiguous() + if "depth_conf" in out: + conf = torch.nn.functional.interpolate( + out["depth_conf"].unsqueeze(1).float(), size=(H, W), + mode="bilinear", align_corners=False, + ).squeeze(1).cpu() + geometry["confidence"] = conf.contiguous() + return io.NodeOutput(geometry) + + +class DA3Render(io.ComfyNode): + """Render a visualization from a DA3_GEOMETRY packet.""" + + _DEPTH_RENDER_INPUTS = [ + io.Combo.Input("normalization", + options=["v2_style", "min_max", "raw"], + default="v2_style", + tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n" + "- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n" + "- raw: no scaling,preserves metric units for Metric model."), + io.Boolean.Input("apply_sky_clip", default=False, + tooltip="Clip sky-region depth to 
the 99th percentile of foreground depth before normalisation. " + "Requires a sky key in the da3_geometry input (for Mono/Metric models only)."), + ] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3Render", + display_name="Render Depth Anything 3", + category="image/geometry estimation", + description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.DynamicCombo.Input("output", + tooltip="- depth: normalised greyscale depth image.\n" + "- depth_colored: depth mapped through the Turbo colormap.\n" + "- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n" + "- confidence: normalised depth confidence (for Small/Base models only).", + options=[ + io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS), + io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS), + io.DynamicCombo.Option("sky_mask", [ + io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."), + ]), + io.DynamicCombo.Option("confidence", [ + io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."), + ]), + ]), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, da3_geometry, output) -> io.NodeOutput: + output_val = output["output"] + + if output_val in ("depth", "depth_colored"): + normalization = output["normalization"] + apply_sky_clip = output["apply_sky_clip"] + if apply_sky_clip and "sky" not in da3_geometry: + raise ValueError( + "apply_sky_clip=True requires a sky tensor in the da3_geometry input, but none is present. " + "Run with Mono/Metric models or set apply_sky_clip=False." 
+ ) + depth = da3_geometry["depth"] + sky = da3_geometry.get("sky") + if apply_sky_clip and sky is not None: + depth = torch.stack([ + da3_preprocess.apply_sky_aware_clip(depth[i], sky[i]) + for i in range(depth.shape[0]) + ], dim=0) + grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale + result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey + + elif output_val == "sky_mask": + if "sky" not in da3_geometry: + raise ValueError("geometry has no sky output; run with Mono/Metric models.") + sky = da3_geometry["sky"] + if output["colored"]: + result = _turbo(sky) + else: + result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous() + + elif output_val == "confidence": + if "confidence" not in da3_geometry: + raise ValueError("da3_geometry has no confidence output; run with Small/Base models.") + conf = _normalize_confidence(da3_geometry["confidence"]) + if output["colored"]: + result = _turbo(conf) + else: + result = conf.unsqueeze(-1).expand(*conf.shape, 3).contiguous() + + else: + raise ValueError(f"Unknown output mode: {output_val}") + + return io.NodeOutput(result.float()) + + @staticmethod + def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor: + """Normalise depth and pack as an (B,H,W,3) image tensor.""" + + N = depth.shape[0] + if normalization == "v2_style": + norm = torch.stack([ + da3_preprocess.normalize_depth_v2_style( + depth[i], sky_for_norm[i] if sky_for_norm is not None else None) + for i in range(N) + ], dim=0) + elif normalization == "min_max": + norm = da3_preprocess.normalize_depth_min_max(depth) + else: + norm = depth + + out = norm.unsqueeze(-1).repeat(1, 1, 1, 3) + if normalization != "raw": + out = out.clamp(0.0, 1.0) + return out.contiguous() + + +class DA3GeometryToMesh(io.ComfyNode): + """Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating.""" + + @classmethod + def define_schema(cls): + return io.Schema( 
+ node_id="DA3GeometryToMesh", + search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"], + display_name="Convert DA3 Geometry to Mesh", + category="image/geometry estimation", + description="Convert a depth map into a triangulated 3D mesh.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."), + io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."), + io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). " + "Used when the geometry has a confidence map (Small/Base models)."), + io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. 
Used when the geometry has a sky map (Mono/Metric models)."), + io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."), + ], + outputs=[io.Mesh.Output()], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index] # (H, W) + H, W = depth.shape + + # NaN/inf depth would propagate silently through unproject and produce an + # empty mesh; replace them with 0 here so those pixels are later excluded + # by the isfinite check inside triangulate_grid_mesh. + depth = depth.clone() + n_bad = (~torch.isfinite(depth)).sum().item() + if n_bad: + logging.getLogger("comfy").warning( + f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels " + f"({100*n_bad/(H*W):.1f}%) - zeroed before unproject." + ) + depth[~torch.isfinite(depth)] = 0.0 + logging.getLogger("comfy").debug( + f"DA3GeometryToMesh: depth[{batch_index}] range " + f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}" + ) + + K = _da3_get_K(da3_geometry, batch_index, H, W) + points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space + + # Apply world-to-camera inverse so multi-view frames share a common world frame. + E = _da3_get_extrinsic(da3_geometry, batch_index) + if E is not None: + points = _da3_apply_extrinsic(points, E) + + # Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + # Also exclude pixels where depth was invalid. 
+ mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index]) + points = points.clone() + points[~mask] = float('inf') + + verts, faces, uvs = triangulate_grid_mesh( + points, + decimation=decimation, + discontinuity_threshold=discontinuity_threshold, + depth=depth, + ) + if verts.shape[0] == 0 or faces.shape[0] == 0: + raise ValueError( + "DA3GeometryToMesh produced an empty mesh. " + "Try raising discontinuity_threshold, lowering confidence_threshold, " + "or disabling use_sky_mask." + ) + + # OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back). + # Same transform as MoGePointMapToMesh perspective branch; the sign vector is built on verts' device to avoid a CPU/GPU mismatch. + verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype, device=verts.device) + faces = faces[:, [0, 2, 1]].contiguous() + + tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None + mesh = Types.MESH( + vertices=verts.unsqueeze(0), + faces=faces.unsqueeze(0), + uvs=uvs.unsqueeze(0), + texture=tex, + ) + return io.NodeOutput(mesh) + + +class DA3GeometryToPointCloud(io.ComfyNode): + """Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DA3GeometryToPointCloud", + search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"], + display_name="Convert DA3 Geometry to Point Cloud", + category="image/geometry estimation", + description="Convert a depth map into a 3D point cloud.", + inputs=[ + DA3Geometry.Input("da3_geometry"), + io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."), + io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01, + tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."), + io.Boolean.Input("use_sky_mask", default=True, + tooltip="Exclude sky-probability pixels (sky >= 0.5). 
Used when the geometry has a sky map (Mono/Metric models)."), + io.Int.Input("downsample", default=1, min=1, max=16, + tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."), + ], + # TODO: add a proper PointCloud output type + outputs=[DA3PointCloud.Output(display_name="point_cloud")], + ) + + @classmethod + def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput: + depth_all = da3_geometry["depth"] # (B, H, W) + B = depth_all.shape[0] + if batch_index >= B: + raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.") + + depth = depth_all[batch_index].clone() # (H, W) + depth[~torch.isfinite(depth)] = 0.0 + H, W = depth.shape + + K = _da3_get_K(da3_geometry, batch_index, H, W) + + if downsample > 1: + depth = depth[::downsample, ::downsample].contiguous() + # Scale intrinsics to the downsampled grid. + K = K.clone() + K[0, :] /= downsample + K[1, :] /= downsample + + H_ds, W_ds = depth.shape + points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space + + # Apply world-to-camera inverse so multi-view frames share a common world frame. + E = _da3_get_extrinsic(da3_geometry, batch_index) + if E is not None: + points = _da3_apply_extrinsic(points, E) + + # Rebuild mask at downsampled resolution. + mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask) + if downsample > 1: + mask = mask[::downsample, ::downsample] + + mask = mask & (depth > 0) # non-finite depth was zeroed above, so isfinite(depth) would always be True + + # OpenCV → glTF: flip Y and Z. 
+ points_gltf = points.clone() + points_gltf[..., 1] *= -1.0 + points_gltf[..., 2] *= -1.0 + + pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)] + + colors_flat = None + if "image" in da3_geometry: + img = da3_geometry["image"][batch_index] # (H, W, 3) + if downsample > 1: + img = img[::downsample, ::downsample] + colors_flat = img.reshape(-1, 3)[mask.reshape(-1)] + + conf_flat = None + if "confidence" in da3_geometry: + conf = da3_geometry["confidence"][batch_index] # (H, W) + if downsample > 1: + conf = conf[::downsample, ::downsample] + conf_flat = conf.reshape(-1)[mask.reshape(-1)] + + if pts_flat.shape[0] == 0: + raise ValueError( + "DA3GeometryToPointCloud produced zero points after filtering. " + "Try lowering confidence_threshold or disabling use_sky_mask." + ) + + return io.NodeOutput({ + "points": pts_flat, + "colors": colors_flat, + "confidence": conf_flat, + }) + + +class DA3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LoadDA3Model, + DA3Inference, + DA3Render, + DA3GeometryToMesh, + # DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type + ] + + +async def comfy_entrypoint() -> DA3Extension: + return DA3Extension() diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index afc663b22..ef1757ae5 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -245,6 +245,11 @@ class KV_Attn_Input: cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) if cache_key in self.cache: kk, vv = self.cache[cache_key] + + # Fix batch size changing. 
+ kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0]) + vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0]) + self.set_cache = False return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py index 2ba3a3820..116c14fde 100644 --- a/comfy_extras/nodes_gaussian_splat.py +++ b/comfy_extras/nodes_gaussian_splat.py @@ -488,7 +488,7 @@ class SplatToFile3D(IO.ComfyNode): "spz: Niantic gzip-compressed (~10x smaller), base color only " ), ], - outputs=[IO.File3DAny.Output(display_name="model_3d")], + outputs=[IO.File3DSplatAny.Output(display_name="model_3d")], ) @classmethod @@ -516,7 +516,7 @@ class File3DToSplat(IO.ComfyNode): inputs=[ IO.MultiType.Input( IO.File3DAny.Input("model_3d"), - types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], tooltip="A gaussian splat 3D file", ), ], diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index b339dc4ff..455897859 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -51,6 +51,14 @@ class Load3D(IO.ComfyNode): ], ) + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + @classmethod def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) @@ -136,7 +144,7 @@ class Preview3DAdvanced(IO.ComfyNode): is_output_node=True, inputs=[ IO.MultiType.Input( - "model_file", + "model_3d", types=[ IO.File3DGLB, IO.File3DGLTF, @@ -148,34 +156,161 @@ class Preview3DAdvanced(IO.ComfyNode): ], tooltip="3D model file from an upstream 3D node.", ), - IO.Load3D.Input("image"), - 
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), ], outputs=[ - IO.File3DAny.Output(display_name="model_file"), - IO.Load3DCamera.Output(display_name="camera_info"), + IO.File3DAny.Output(display_name="model_3d"), IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), IO.Int.Output(display_name="width"), IO.Int.Output(display_name="height"), ], ) @classmethod - def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput: - filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}" - model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) camera_info_input = kwargs.get("camera_info", None) - camera_info = camera_info_input if camera_info_input is not None else image['camera_info'] + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] model_3d_info_input = kwargs.get("model_3d_info", None) - model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', []) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) return IO.NodeOutput( - model_file, - camera_info, + model_3d, model_3d_info, + camera_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + +class 
PreviewGaussianSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewGaussianSplat", + display_name="Preview Splat", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view splat", + "view gaussian", + "view gaussian splat", + "preview gaussian", + "preview gaussian splat", + "view 3dgs", + "preview 3dgs", + "preview ply", + "preview spz", + "preview splat", + "preview ksplat", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DSplatAny, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + ], + tooltip="A gaussian splat 3D file.", + ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DSplatAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + model_3d_info, + camera_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, 
model_3d_info), + ) + + +class PreviewPointCloud(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewPointCloud", + display_name="Preview Point Cloud", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view point cloud", + "view pointcloud", + "preview point cloud", + "preview pointcloud", + "preview ply", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DPointCloudAny, + IO.File3DPLY, + ], + tooltip="Point cloud file (.ply)", + ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DPointCloudAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + model_3d_info, + camera_info, width, height, ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), @@ -189,6 +324,8 @@ class Load3DExtension(ComfyExtension): Load3D, Preview3D, Preview3DAdvanced, + 
PreviewGaussianSplat, + PreviewPointCloud, ] diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 422949531..a63f0414b 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -8,6 +8,7 @@ import folder_paths from comfy_api.latest import ComfyExtension, Types, io from typing_extensions import override +from comfy.ldm.colormap import turbo as _turbo from comfy.ldm.moge.model import MoGeModel from comfy.ldm.moge.geometry import triangulate_grid_mesh from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid @@ -27,19 +28,6 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY") # "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present) -def _turbo(x: torch.Tensor) -> torch.Tensor: - """Anton Mikhailov polynomial approximation of the turbo colormap.""" - x = x.clamp(0.0, 1.0) - x2 = x * x - x3 = x2 * x - x4 = x2 * x2 - x5 = x4 * x - r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5 - g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5 - b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5 - return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0) - - def _normals_from_points(points: torch.Tensor) -> torch.Tensor: """Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback).""" finite = torch.isfinite(points).all(dim=-1) diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index dc405291c..083e47ae4 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -6,24 +6,24 @@ from comfy_api.latest import ComfyExtension, io class AspectRatio(str, Enum): SQUARE = "1:1 (Square)" + PHOTO_V = "2:3 (Portrait Photo)" PHOTO_H = "3:2 (Photo)" + STANDARD_V = "3:4 (Portrait Standard)" STANDARD_H = "4:3 (Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" 
WIDESCREEN_H = "16:9 (Widescreen)" ULTRAWIDE_H = "21:9 (Ultrawide)" - PHOTO_V = "2:3 (Portrait Photo)" - STANDARD_V = "3:4 (Portrait Standard)" - WIDESCREEN_V = "9:16 (Portrait Widescreen)" ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_V: (2, 3), AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_V: (3, 4), AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_V: (9, 16), AspectRatio.WIDESCREEN_H: (16, 9), AspectRatio.ULTRAWIDE_H: (21, 9), - AspectRatio.PHOTO_V: (2, 3), - AspectRatio.STANDARD_V: (3, 4), - AspectRatio.WIDESCREEN_V: (9, 16), } @@ -50,26 +50,35 @@ class ResolutionSelector(io.ComfyNode): min=0.1, max=16.0, step=0.1, - tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + tooltip="Target total megapixels. 1.0 MP ≈ 1024x1024 for square.", + ), + io.Int.Input( + id="multiple", + default=8, + min=8, + max=128, + step=4, + tooltip="Width and height are rounded to the nearest multiple of this value.", + advanced=True, ), ], outputs=[ io.Int.Output( - "width", tooltip="Calculated width in pixels (multiple of 8)." + "width", tooltip="Calculated width in pixels, rounded to the nearest multiple of the selected value." ), io.Int.Output( - "height", tooltip="Calculated height in pixels (multiple of 8)." + "height", tooltip="Calculated height in pixels, rounded to the nearest multiple of the selected value." 
), ], ) @classmethod - def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + def execute(cls, aspect_ratio: str, megapixels: float, multiple: int) -> io.NodeOutput: w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] total_pixels = megapixels * 1024 * 1024 scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) - width = round(w_ratio * scale / 8) * 8 - height = round(h_ratio * scale / 8) * 8 + width = round(w_ratio * scale / multiple) * multiple + height = round(h_ratio * scale / multiple) * multiple return io.NodeOutput(width, height) diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index a91549e7f..1b6592bb2 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -337,6 +337,12 @@ class SaveGLB(IO.ComfyNode): IO.File3DFBX, IO.File3DSTL, IO.File3DUSDZ, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + IO.File3DSplatAny, + IO.File3DPointCloudAny, IO.File3DAny, ], tooltip="Mesh or 3D file to save", diff --git a/comfy_extras/nodes_scail.py b/comfy_extras/nodes_scail.py new file mode 100644 index 000000000..bba0942d7 --- /dev/null +++ b/comfy_extras/nodes_scail.py @@ -0,0 +1,325 @@ +"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3 +preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes.""" + +from typing_extensions import override + +import torch +import torch.nn.functional as F + +import nodes +import node_helpers +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, io +from comfy.ldm.sam3.tracker import unpack_masks + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + + +# Model was trained on these exact colors; deviating degrades multi-identity quality. 
+DEFAULT_PALETTE = [ + (0.0, 0.0, 1.0), # Blue + (1.0, 0.0, 0.0), # Red + (0.0, 1.0, 0.0), # Green + (1.0, 0.0, 1.0), # Magenta + (0.0, 1.0, 1.0), # Cyan + (1.0, 1.0, 0.0), # Yellow +] + + +def _unpack(track_data): + packed = track_data["packed_masks"] + if packed is None or packed.shape[1] == 0: + return None + return unpack_masks(packed) + + +def _first_frame_cx_area(masks_bool): + first = masks_bool[0].float() + H, W = first.shape[-2], first.shape[-1] + n_pixels = H * W + grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W) + area = first.sum(dim=(-1, -2)).clamp_(min=1) + cx = (first * grid_x).sum(dim=(-1, -2)) / area + return (cx / W).tolist(), (area / n_pixels).tolist() + + +def _subset_track_data(track_data, obj_indices): + out = dict(track_data) + packed = track_data["packed_masks"] + if packed is None or not obj_indices: + out["packed_masks"] = None + if "scores" in out: + out["scores"] = [] + return out + out["packed_masks"] = packed[:, obj_indices].contiguous() + scores = track_data.get("scores") + if scores is not None: + out["scores"] = [scores[i] for i in obj_indices if i < len(scores)] + return out + + +def _render_colored_masks(track_data, background="black"): + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + device = comfy.model_management.intermediate_device() + dtype = comfy.model_management.intermediate_dtype() + bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0) + if packed is None or packed.shape[1] == 0: + T = track_data.get("n_frames", 1) if packed is None else packed.shape[0] + out = torch.empty(T, H, W, 3, device=device, dtype=dtype) + out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2] + return out + T, N_obj = packed.shape[0], packed.shape[1] + colors = torch.tensor( + [DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)], + device=device, dtype=dtype, + ) + masks_full = unpack_masks(packed.to(device)).float() + Hm, Wm = 
masks_full.shape[-2], masks_full.shape[-1] + masks_full = F.interpolate( + masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest" + ).view(T, N_obj, H, W) > 0.5 + any_mask = masks_full.any(dim=1) + obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1) + color_overlay = colors[obj_idx_map] + bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3) + return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay)) + + +def _extract_mask_to_28ch(rgb_video): + """Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent + (1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c) + threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking.""" + T, H, W, _ = rgb_video.shape + _ON_THRESH = 225.0 / 255.0 + mask = rgb_video.movedim(-1, 1).float() + R = (mask[:, 0:1] > _ON_THRESH).float() + G = (mask[:, 1:2] > _ON_THRESH).float() + B = (mask[:, 2:3] > _ON_THRESH).float() + nR, nG, nB = 1 - R, 1 - G, 1 - B + binary_7ch = torch.cat([ + R * G * B, # white + R * nG * nB, # red + nR * G * nB, # green + nR * nG * B, # blue + R * G * nB, # yellow + R * nG * B, # magenta + nR * G * B, # cyan + ], dim=1) + H_lat, W_lat = H, W + for _ in range(3): + H_lat = (H_lat + 1) // 2 + W_lat = (W_lat + 1) // 2 + binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area') + T_latent = (T - 1) // 4 + 1 + padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0) + out = padded.view(T_latent, 28, H_lat, W_lat) + return out.unsqueeze(0) + + +class WanSCAILToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanSCAILToVideo", + category="model/conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), + 
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), + io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."), + io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."), + io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), + io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."), + io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."), + io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."), + io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."), + io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."), + io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."), + io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. 
SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."), + io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), + io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, + video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None, + pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + noise_mask = None + + ref_mask_flag = not replacement_mode + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag}) + + prev_trimmed = None + if previous_frames is not None and previous_frames.shape[0] > 0: + prev_trimmed = previous_frames[-previous_frame_count:] + video_frame_offset -= prev_trimmed.shape[0] + video_frame_offset = max(0, video_frame_offset) + + ref_latent = None + if reference_image is not None: + reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + # Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte + if replacement_mode and reference_image_mask is not None: + rm = 
comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1) + is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype) + reference_image = reference_image * is_char + ref_latent = vae.encode(reference_image[:, :, :, :3]) + + if ref_latent is not None: + positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) + negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + if pose_video is not None: + if pose_video.shape[0] <= video_frame_offset: + pose_video = None + else: + pose_video = pose_video[video_frame_offset:] + if pose_video_mask is not None: + if pose_video_mask.shape[0] <= video_frame_offset: + pose_video_mask = None + else: + pose_video_mask = pose_video_mask[video_frame_offset:] + + # Truncate pose+mask jointly to the shorter of the two, capped at length. 
+ ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None] + if ts: + T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1 + if pose_video is not None: + pose_video = pose_video[:T_kept] + if pose_video_mask is not None: + pose_video_mask = pose_video_mask[:T_kept] + + if pose_video is not None: + pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength + positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) + + if pose_video_mask is not None: + mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) + driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw) + positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch}) + + if reference_image_mask is not None: + ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1) + ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw) + zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype) + ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1) + positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch}) + negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch}) + + if prev_trimmed is not None: + pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", 
"center").movedim(1, -1) + prev_latent = vae.encode(pf[:, :, :, :3]) + prev_latent_frames = min(prev_latent.shape[2], latent.shape[2]) + latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype) + noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype) + noise_mask[:, :, :prev_latent_frames] = 0.0 + + out_latent = {"samples": latent} + if noise_mask is not None: + out_latent["noise_mask"] = noise_mask + return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length) + + +class SCAIL2ColoredMask(io.ComfyNode): + """Render SAM3 tracks for the driving pose video and (optionally) the reference + image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by` + across both outputs guarantees identity K maps to the same color on both + sides, for multi-person workflow consistency. + pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode + reference_image_mask bg is the opposite: white = Animation Mode, black = Replacement Mode + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SCAIL2ColoredMask", + display_name="Create SCAIL-2 Colored Mask", + category="conditioning/video_models/scail", + inputs=[ + SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."), + SAM3TrackData.Input("ref_track_data", optional=True, + tooltip="SAM3 track of the reference image."), + io.String.Input("object_indices", default="", + tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."), + io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right", + tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). 
left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."), + io.Boolean.Input("replacement_mode", default=False, + tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). " + "True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."), + ], + outputs=[ + io.Image.Output("pose_video_mask"), + io.Image.Output("reference_image_mask"), + ], + is_experimental=True, + ) + + @classmethod + def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None): + def _prep(td): + masks_bool = _unpack(td) + if sort_by != "none" and masks_bool is not None: + cx, area = _first_frame_cx_area(masks_bool) + if sort_by == "left_to_right": + order = sorted(range(len(cx)), key=lambda i: cx[i]) + else: # "area" + order = sorted(range(len(area)), key=lambda i: -area[i]) + td = _subset_track_data(td, order) + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + packed = td.get("packed_masks") + n_obj = packed.shape[1] if packed is not None else 0 + indices = [i for i in indices if 0 <= i < n_obj] + td = _subset_track_data(td, indices) + return td + + drv = _prep(driving_track_data) + # Animation: driving=black, ref=white. Replacement: driving=white, ref=black. 
+ mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black") + ref_bg = "black" if replacement_mode else "white" + + if ref_track_data is not None: + ref = _prep(ref_track_data) + reference_image_mask = _render_colored_masks(ref, ref_bg) + else: + H, W = drv["orig_size"] + fill_value = 1.0 if ref_bg == "white" else 0.0 + reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) + + return io.NodeOutput(mask_video, reference_image_mask) + + +class SCAILExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + WanSCAILToVideo, + SCAIL2ColoredMask, + ] + + +async def comfy_entrypoint() -> SCAILExtension: + return SCAILExtension() diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 046eeaaf5..bb68da6fa 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -15,6 +15,7 @@ import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management +from comfy.conds import CONDRegular, CONDList from comfy.cli_args import args, PerformanceFeature import comfy_extras.nodes_custom_sampler import folder_paths @@ -120,6 +121,11 @@ def process_cond_list(d, prefix=""): process_cond_list(v, f"{prefix}.{k}") elif isinstance(v, torch.Tensor): d[k] = v.clone() + elif isinstance(v, CONDList): + v.cond = [t.detach() if isinstance(t, torch.Tensor) else t for t in v.cond] + elif isinstance(v, CONDRegular): + if isinstance(v.cond, torch.Tensor): + v.cond = v.cond.detach() elif isinstance(v, (list, tuple)): for index, item in enumerate(v): process_cond_list(item, f"{prefix}.{k}.{index}") @@ -1143,45 +1149,45 @@ class TrainLoraNode(io.ComfyNode): # Process conditioning positive = _process_conditioning(positive) - # Setup model and dtype - mp = model.clone() - use_grad_scaler = False - lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) 
- if training_dtype != "none": - dtype = node_helpers.string_to_torch_dtype(training_dtype) - mp.set_model_compute_dtype(dtype) - else: - # Detect model's native dtype for autocast - model_dtype = mp.model.get_dtype() - if model_dtype == torch.float16: - dtype = torch.float16 - # GradScaler only supports float16 gradients, not bfloat16. - # Only enable it when lora params will also be in float16. - if lora_dtype != torch.bfloat16: - use_grad_scaler = True - # Warn about fp16 accumulation instability during training - if PerformanceFeature.Fp16Accumulation in args.fast: - logging.warning( - "WARNING: FP16 model detected with fp16_accumulation enabled. " - "This combination can be numerically unstable during training and may cause NaN values. " - "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." - ) - else: - # For fp8, bf16, or other dtypes, use bf16 autocast - dtype = torch.bfloat16 - - # Prepare latents and compute counts - latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 - latents, num_images, multi_res = _prepare_latents_and_count( - latents, latents_dtype, bucket_mode - ) - - # Validate and expand conditioning - positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) - with torch.inference_mode(False): + # Setup model and dtype + mp = model.clone(force_deepcopy=True) + use_grad_scaler = False + lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) + if training_dtype != "none": + dtype = node_helpers.string_to_torch_dtype(training_dtype) + mp.set_model_compute_dtype(dtype) + else: + # Detect model's native dtype for autocast + model_dtype = mp.model.get_dtype() + if model_dtype == torch.float16: + dtype = torch.float16 + # GradScaler only supports float16 gradients, not bfloat16. + # Only enable it when lora params will also be in float16. 
+ if lora_dtype != torch.bfloat16: + use_grad_scaler = True + # Warn about fp16 accumulation instability during training + if PerformanceFeature.Fp16Accumulation in args.fast: + logging.warning( + "WARNING: FP16 model detected with fp16_accumulation enabled. " + "This combination can be numerically unstable during training and may cause NaN values. " + "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." + ) + else: + # For fp8, bf16, or other dtypes, use bf16 autocast + dtype = torch.bfloat16 + + # Prepare latents and compute counts + latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 + latents, num_images, multi_res = _prepare_latents_and_count( + latents, latents_dtype, bucket_mode + ) + + # Validate and expand conditioning + positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode) + # Setup models for training - mp.model.requires_grad_(False) + mp.model.requires_grad_(False).train() # Load existing LoRA weights if provided existing_weights, existing_steps = _load_existing_lora(existing_lora) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index e1d0d8019..514edea20 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -136,6 +136,17 @@ class CreateVideo(io.ComfyNode): io.Image.Input("images", tooltip="The images to create a video from."), io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0), io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."), + io.Int.Input( + "bit_depth", + min=8, + max=10, + default=8, + step=2, + tooltip="Bit depth of the created video. 
10-bit keeps smoother gradients with less" + " banding, but some players and downstream nodes may not support it.", + optional=True, + display_mode=io.NumberDisplay.number, + ), ], outputs=[ io.Video.Output(), @@ -143,9 +154,14 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: + def execute( + cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8, + ) -> io.NodeOutput: return io.NodeOutput( - InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps))) + InputImpl.VideoFromComponents( + Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)), + bit_depth=bit_depth, + ) ) class GetVideoComponents(io.ComfyNode): @@ -156,7 +172,7 @@ class GetVideoComponents(io.ComfyNode): search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="video", - description="Extracts all components from a video: frames, audio, and framerate.", + description="Extracts all components from a video: frames, audio, framerate, and bit depth.", inputs=[ io.Video.Input("video", tooltip="The video to extract components from."), ], @@ -164,13 +180,14 @@ class GetVideoComponents(io.ComfyNode): io.Image.Output(display_name="images"), io.Audio.Output(display_name="audio"), io.Float.Output(display_name="fps"), + io.Int.Output(display_name="bit_depth"), ], ) @classmethod def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) + return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth()) class LoadVideo(io.ComfyNode): diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index 67d3a8443..d73be8e00 100644 --- a/comfy_extras/nodes_wan.py +++ 
b/comfy_extras/nodes_wan.py @@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode): return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) -class WanSCAILToVideo(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="WanSCAILToVideo", - category="model/conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Vae.Input("vae"), - io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32), - io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), - io.Int.Input("batch_size", default=1, min=1, max=4096), - io.ClipVisionOutput.Input("clip_vision_output", optional=True), - io.Image.Input("reference_image", optional=True), - io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."), - io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."), - io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."), - io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."), - ], - is_experimental=True, - ) - - @classmethod - def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput: - latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - - 
ref_latent = None - if reference_image is not None: - reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - ref_latent = vae.encode(reference_image[:, :, :, :3]) - - if ref_latent is not None: - positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True) - negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True) - - if clip_vision_output is not None: - positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) - negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) - - if pose_video is not None: - pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1) - pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength - positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end) - - out_latent = {} - out_latent["samples"] = latent - return io.NodeOutput(positive, negative, out_latent) - - class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension): WanAnimateToVideo, Wan22ImageToVideoLatent, WanInfiniteTalkToVideo, - WanSCAILToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/cuda_malloc.py b/cuda_malloc.py index f7651981c..8c4422db8 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -2,6 +2,7 @@ import os import importlib.util from comfy.cli_args import args, PerformanceFeature import subprocess +import re #Can't use pytorch to get the GPU names because the 
cuda malloc has to be set before the first import. def get_gpu_names(): @@ -77,11 +78,24 @@ try: except: pass +def get_raw_cuda_version(version_str): + match = re.search(r'\+cu(\d+)', version_str) + if match: + try: + return int(match.group(1)) + except: + pass + return None + if not args.cuda_malloc: try: if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc - args.cuda_malloc = cuda_malloc_supported() + cuda_version = get_raw_cuda_version(version) + if cuda_version is not None and cuda_version >= 130: + args.cuda_malloc = True + else: + args.cuda_malloc = cuda_malloc_supported() except: pass diff --git a/execution.py b/execution.py index 5246d651c..9e16e451d 100644 --- a/execution.py +++ b/execution.py @@ -40,6 +40,7 @@ from comfy_execution.graph_utils import GraphBuilder, is_link from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext +from comfy_execution.asset_enrichment import enrich_output_with_assets from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.latest import io, _io from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger @@ -199,6 +200,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None) if io.Hidden.api_key_comfy_org.name in hidden: hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None) + if io.Hidden.comfy_usage_source.name in hidden: + hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", 
None) else: if "hidden" in valid_inputs: h = valid_inputs["hidden"] @@ -215,6 +218,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] + if h[x] == "COMFY_USAGE_SOURCE": + input_data_all[x] = [extra_data.get("comfy_usage_source", None)] v3_data["hidden_inputs"] = hidden_inputs_v3 return input_data_all, missing_keys, v3_data @@ -418,6 +423,7 @@ def _is_intermediate_output(dynprompt, node_id): class_def = nodes.NODE_CLASS_MAPPINGS[class_type] return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False) + def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs): if server.client_id is None: return @@ -552,6 +558,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, asyncio.create_task(await_completion()) return (ExecutionResult.PENDING, None, None) if len(output_ui) > 0: + # Enrich at output-processing time (not in the send path) so assets + # are registered even when no client is connected, and the asset id + # flows into ui_outputs and the cache alongside the raw entries. 
+ output_ui = enrich_output_with_assets(output_ui) ui_outputs[unique_id] = { "meta": { "node_id": unique_id, diff --git a/main.py b/main.py index 239a52013..0ad660376 100644 --- a/main.py +++ b/main.py @@ -26,6 +26,7 @@ import utils.extra_config from utils.mime_types import init_mime_types import faulthandler import logging +import signal import sys from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context @@ -37,7 +38,19 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -faulthandler.enable(file=sys.stderr, all_threads=False) +faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang) +if __name__ == "__main__" and args.debug_hang: + dumping_traceback = False + + def dump_traceback_on_sigint(signum, frame): + global dumping_traceback + if dumping_traceback: + raise KeyboardInterrupt + dumping_traceback = True + faulthandler.dump_traceback(file=sys.stderr, all_threads=True) + raise KeyboardInterrupt + + signal.signal(signal.SIGINT, dump_traceback_on_sigint) import comfy_aimdo.control @@ -477,6 +490,11 @@ def start_comfyui(asyncio_loop=None): init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, init_api_nodes=not args.disable_api_nodes )) + + # Re-apply Comfy's cuDNN benchmark policy after custom-node imports. Benchmark + # mode can request near-card-sized autotune workspaces, and some custom nodes set it at import time. 
+ comfy.model_management.set_cudnn_benchmark() + hook_breaker_ac10a0.restore_functions() cuda_malloc_warning() diff --git a/nodes.py b/nodes.py index 6b7997159..c36aae17d 100644 --- a/nodes.py +++ b/nodes.py @@ -2406,6 +2406,7 @@ async def init_builtin_extra_nodes(): "nodes_video.py", "nodes_lumina2.py", "nodes_wan.py", + "nodes_bernini.py", "nodes_lotus.py", "nodes_hunyuan3d.py", "nodes_primitive.py", @@ -2452,6 +2453,7 @@ async def init_builtin_extra_nodes(): "nodes_rtdetr.py", "nodes_frame_interpolation.py", "nodes_sam3.py", + "nodes_scail.py", "nodes_void.py", "nodes_wandancer.py", "nodes_hidream_o1.py", @@ -2459,7 +2461,8 @@ async def init_builtin_extra_nodes(): "nodes_moge.py", "nodes_mediapipe.py", "nodes_gaussian_splat.py", - "nodes_triposplat.py" + "nodes_triposplat.py", + "nodes_depth_anything_3.py", ] import_failed = [] diff --git a/openapi.yaml b/openapi.yaml index b7e21245f..6e203b1cd 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3,11 +3,6 @@ components: Asset: description: Represents a user-owned asset (image, video, or other generated output). properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string created_at: description: Timestamp when the asset was created format: date-time @@ -16,8 +11,12 @@ components: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -139,17 +138,16 @@ components: AssetUpdated: description: Response returned when an existing asset is successfully updated. 
properties: - asset_hash: - deprecated: true - description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' - pattern: ^blake3:[a-f0-9]{64}$ - type: string display_name: description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true type: string + file_path: + description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors") + nullable: true + type: string hash: - description: Blake3 hash of the asset content. Preferred over asset_hash. + description: Blake3 hash of the asset content. pattern: ^blake3:[a-f0-9]{64}$ type: string id: @@ -828,7 +826,11 @@ components: type: string type: object PaginationInfo: - description: Offset/limit-based pagination metadata included in list responses. + description: | + Pagination metadata included in list responses. Supports both legacy + offset/limit pagination and cursor-based pagination. When cursor-based + pagination is used, `next_cursor` is the primary pagination token and + `offset`/`total` may be zero. properties: has_more: description: Whether more items are available beyond this page @@ -837,12 +839,19 @@ components: description: Items per page minimum: 1 type: integer + next_cursor: + description: | + Opaque cursor for the next page. Pass this value as the `after` + query parameter on the next request. Empty or absent when there + are no more results. + type: string offset: - description: Current offset (0-based) + deprecated: true + description: 'Current offset (0-based). Deprecated: use cursor-based pagination.' minimum: 0 type: integer total: - description: Total number of items matching filters + description: Total number of items matching filters (may be 0 when using cursor pagination) minimum: 0 type: integer required: @@ -887,6 +896,11 @@ components: additionalProperties: true description: The workflow graph to execute type: object + prompt_id: + description: Optional client-supplied job id. 
Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one. + format: uuid + nullable: true + type: string workflow_id: description: UUID identifying the cloud workflow entity to associate with this job type: string @@ -1053,6 +1067,9 @@ components: comfyui_version: description: ComfyUI version type: string + deploy_environment: + description: How this ComfyUI instance is deployed (e.g. cloud, local-git, local-portable, local-desktop) + type: string embedded_python: description: Whether using embedded Python type: boolean @@ -1518,17 +1535,11 @@ paths: schema: default: true type: boolean - - description: Filter assets by exact content hash. Preferred over asset_hash. + - description: Filter assets by exact content hash. in: query name: hash schema: type: string - - deprecated: true - description: 'Deprecated: use hash instead. Filter assets by exact content hash.' - in: query - name: asset_hash - schema: - type: string - description: | Opaque cursor for keyset pagination. Pass the `next_cursor` value from the previous response to fetch the next page. When provided, @@ -1571,42 +1582,12 @@ paths: - file post: description: | - Uploads a new asset to the system with associated metadata. - Supports two upload methods: - 1. Direct file upload (multipart/form-data) - 2. URL-based upload (application/json with source: "url") + Creates a new asset from a direct file upload (multipart/form-data) with associated metadata. If an asset with the same hash already exists, returns the existing asset. - operationId: uploadAsset + operationId: createAsset requestBody: content: - application/json: - schema: - properties: - name: - description: Display name for the asset (used to determine file extension) - type: string - preview_id: - description: Optional preview asset ID - format: uuid - type: string - tags: - description: Freeform tags for the asset. 
Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array - url: - description: HTTP/HTTPS URL to download the asset from - format: uri - type: string - user_metadata: - additionalProperties: true - description: Custom metadata to store with the asset - type: object - required: - - url - - name - type: object multipart/form-data: schema: properties: @@ -1614,6 +1595,10 @@ paths: description: The asset file to upload format: binary type: string + hash: + description: Content hash of the file. + pattern: ^(blake3|sha256):[a-f0-9]{64}$ + type: string id: description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset. format: uuid @@ -1629,10 +1614,8 @@ paths: format: uuid type: string tags: - description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. - items: - type: string - type: array + description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + type: string user_metadata: description: Custom JSON metadata as a string type: string @@ -1641,36 +1624,32 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset already existed for this user (deduplicated by content hash); the + existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset created successfully + description: Asset created successfully (created_new=true) "400": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Invalid request (bad file, invalid URL, invalid content type, etc.) 
+ description: Invalid request (bad file, invalid content type, etc.) "401": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Unauthorized - "403": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL requires authentication or access denied - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Source URL not found "413": content: application/json: @@ -1683,19 +1662,13 @@ paths: schema: $ref: '#/components/schemas/ErrorResponse' description: Unsupported media type - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Download failed due to network error or timeout "500": content: application/json: schema: $ref: '#/components/schemas/ErrorResponse' description: Internal server error - summary: Upload a new asset + summary: Create a new asset tags: - file /api/assets/{id}: @@ -1730,7 +1703,7 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' - description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) + description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)' "500": content: application/json: @@ -1783,7 +1756,7 @@ paths: description: | Updates an asset's metadata. At least one field must be provided. Only name, mime_type, preview_id, and user_metadata can be updated. - For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint. + For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags. 
operationId: updateAsset parameters: - description: Asset ID @@ -1982,76 +1955,6 @@ paths: summary: Add tags to asset tags: - file - put: - description: Adds and removes tags from an asset in a single operation - operationId: updateAssetTags - parameters: - - description: Asset ID - in: path - name: id - required: true - schema: - format: uuid - type: string - requestBody: - content: - application/json: - schema: - description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items. - minProperties: 1 - properties: - add: - description: Tags to add to the asset. Can be empty if remove has items. - items: - type: string - type: array - remove: - description: Tags to remove from the asset. Can be empty if add has items. - items: - type: string - type: array - type: object - required: true - responses: - "200": - content: - application/json: - schema: - $ref: '#/components/schemas/TagsModificationResponse' - description: Tags updated successfully - "400": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Invalid request - "401": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Unauthorized - "404": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Asset not found - "422": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Reserved tag validation error - "500": - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - description: Internal server error - summary: Update asset tags - tags: - - file /api/assets/from-hash: post: description: | @@ -2065,8 +1968,8 @@ paths: schema: properties: hash: - description: Hash of the existing asset. 
Supports Blake3 (blake3:) or SHA256 (sha256:) formats - pattern: ^(blake3|sha256):[a-f0-9]{64}$ + description: 'Blake3 content hash of the existing asset (blake3: prefix)' + pattern: ^blake3:[a-f0-9]{64}$ type: string mime_type: description: MIME type of the asset (e.g., "image/png", "video/mp4") @@ -2090,12 +1993,20 @@ paths: type: object required: true responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: | + Asset reference already existed for this user (deduplicated by content + hash); the existing asset is returned with created_new=false. "201": content: application/json: schema: $ref: '#/components/schemas/AssetCreated' - description: Asset reference created successfully + description: Asset reference created successfully (created_new=true) "400": content: application/json: @@ -2887,7 +2798,21 @@ paths: - asc - desc type: string - - description: Pagination offset (0-based) + - description: | + Opaque cursor for keyset pagination. Pass the `next_cursor` value + from a previous response to fetch the next page. + Cursor pagination is supported only when `sort_by=create_time` + (default). If `sort_by=execution_time`, `after` is ignored and + offset/limit pagination is used. + Cursors are opaque base64url payloads — clients should treat them + as strings and not parse the contents. + example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0 + in: query + name: after + schema: + type: string + - deprecated: true + description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.' in: query name: offset schema: @@ -2909,6 +2834,12 @@ paths: schema: $ref: '#/components/schemas/JobsListResponse' description: Success - Jobs retrieved + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request (e.g. malformed pagination cursor). 
"401": content: application/json: diff --git a/requirements.txt b/requirements.txt index 79d38fc06..a49d968af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.94 -comfyui-embedded-docs==0.5.2 +comfyui-frontend-package==1.45.15 +comfyui-workflow-templates==0.9.98 +comfyui-embedded-docs==0.5.3 torch torchsde torchvision @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=16.0.0 comfy-kitchen==0.2.10 -comfy-aimdo==0.4.8 +comfy-aimdo==0.4.9 requests simpleeval>=1.0.0 blake3 diff --git a/server.py b/server.py index 268441bd1..6b0029adf 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ import time import nodes import folder_paths import execution -from comfy_execution.jobs import JobStatus, get_job, get_all_jobs +from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id import uuid import urllib import json @@ -27,6 +27,7 @@ import logging import mimetypes from comfy.cli_args import args +from comfy.deploy_environment import get_deploy_environment import comfy.utils import comfy.model_management from comfy_api import feature_flags @@ -690,6 +691,7 @@ class PromptServer(): "python_version": sys.version, "pytorch_version": comfy.model_management.torch_version, "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded", + "deploy_environment": get_deploy_environment(), "argv": sys.argv }, "devices": device_entries @@ -942,7 +944,21 @@ class PromptServer(): if "prompt" in json_data: prompt = json_data["prompt"] - prompt_id = str(json_data.get("prompt_id", uuid.uuid4())) + client_prompt_id = json_data.get("prompt_id") + if client_prompt_id is None: + # Absent or explicit null: the server mints the id. 
+ prompt_id = str(uuid.uuid4()) + else: + try: + prompt_id = validate_job_id(client_prompt_id) + except ValueError: + error = { + "type": "invalid_prompt_id", + "message": "prompt_id must be a valid UUID", + "details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one", + "extra_info": {} + } + return web.json_response({"error": error, "node_errors": {}}, status=400) partial_execution_targets = None if "partial_execution_targets" in json_data: @@ -957,6 +973,11 @@ class PromptServer(): if "client_id" in json_data: extra_data["client_id"] = json_data["client_id"] + + if "comfy_usage_source" not in extra_data: + usage_source = request.headers.get("Comfy-Usage-Source") + if usage_source: + extra_data["comfy_usage_source"] = usage_source if valid[0]: outputs_to_execute = valid[2] sensitive = {} @@ -1253,6 +1274,15 @@ class PromptServer(): if verbose: logging.info("Starting server\n") + if args.debug_hang: + logging.info( + f"{'-' * 80}\n" + "ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n" + "the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n" + "terminal to dump the python backtraces for debugging. Please attach the extra\n" + "debug info to your bug report.\n" + f"{'-' * 80}" + ) for addr in addresses: address = addr[0] port = addr[1] diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py index 9867b4e14..4aa20372f 100644 --- a/tests-unit/assets_test/conftest.py +++ b/tests-unit/assets_test/conftest.py @@ -6,6 +6,7 @@ import subprocess import sys import tempfile import time +import uuid from pathlib import Path from typing import Callable, Iterator, Optional @@ -188,9 +189,17 @@ def _post_multipart_asset( @pytest.fixture def make_asset_bytes() -> Callable[[str, int], bytes]: + # Salt content per test so it never collides with assets left over from + # earlier tests. 
Delete is now always a soft delete (content is preserved), + # so the suite can no longer rely on hard-deleting content for isolation. + # Deterministic within a test: the same (name, size) yields the same bytes. + salt = uuid.uuid4().bytes + def _make(name: str, size: int = 8192) -> bytes: seed = sum(ord(c) for c in name) % 251 - return bytes((i * 31 + seed) % 256 for i in range(size)) + body = bytearray((i * 31 + seed) % 256 for i in range(size)) + body[: len(salt)] = salt[:size] + return bytes(body) return _make @@ -212,7 +221,7 @@ def asset_factory(http: requests.Session, api_base: str): for aid in created: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) @pytest.fixture @@ -227,7 +236,11 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas if tags is None: tags = ["models", "checkpoints", "unit-tests", "alpha"] meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} - files = {"file": (name, b"A" * 4096, "application/octet-stream")} + # Unique content per test so the seed always creates a fresh asset (201). + # Delete is now always a soft delete, so content from a prior test survives + # and would otherwise dedup this upload into an existing asset (200). 
+ content = uuid.uuid4().bytes + b"A" * (4096 - 16) + files = {"file": (name, content, "application/octet-stream")} form_data = { "tags": json.dumps(tags), "name": name, @@ -260,4 +273,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str): break for aid in ids: with contextlib.suppress(Exception): - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30) + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) diff --git a/tests-unit/assets_test/queries/test_asset_reference_keyset.py b/tests-unit/assets_test/queries/test_asset_reference_keyset.py new file mode 100644 index 000000000..d143d60f9 --- /dev/null +++ b/tests-unit/assets_test/queries/test_asset_reference_keyset.py @@ -0,0 +1,112 @@ +"""Keyset-pagination tiebreaker tests for list_references_page. + +When multiple rows share the same primary sort value (e.g. four assets +created in the same microsecond), the secondary `ORDER BY id` is what keeps +keyset pagination from losing or repeating rows. This file exercises that +branch directly against an in-memory SQLite session — engineering identical +timestamps via HTTP is unreliable enough that we work at the query layer. 
+""" +import uuid +from datetime import datetime + +import pytest +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.asset_reference import list_references_page + + +def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference: + asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024) + session.add(asset) + session.flush() + ref = AssetReference( + id=str(uuid.uuid4()), + asset_id=asset.id, + owner_id=owner, + name=name, + file_path=f"/tmp/{name}", + created_at=created_at, + updated_at=created_at, + last_access_time=created_at, + is_missing=False, + ) + session.add(ref) + return ref + + +@pytest.mark.parametrize("order", ["desc", "asc"]) +def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str): + """Four rows with the SAME created_at must paginate cleanly under cursor + mode — no row dropped, no row repeated, despite the primary sort column + being non-discriminating. + """ + shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores + refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)] + session.commit() + + expected_ids = sorted([r.id for r in refs], reverse=(order == "desc")) + + # Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0). + seen: list[str] = [] + after_value = None + after_id = None + for _ in range(4): # generous loop bound; expect 3 iterations (2 rows, 2 rows, then an empty page) + page, _tag_map, _total = list_references_page( + session, + limit=2, + sort="created_at", + order=order, + after_cursor_value=after_value, + after_cursor_id=after_id, + ) + if not page: + break + seen.extend(p.id for p in page) + # Use the last row's (created_at, id) as the next cursor input.
+ last = page[-1] + after_value, after_id = last.created_at, last.id + if len(page) < 2: + break + + assert seen == expected_ids, ( + f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}" + ) + + +def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session): + """Some rows share a timestamp, some don't. The cursor must still walk + every row exactly once regardless of where ties sit relative to a + page boundary.""" + t1 = datetime(2024, 5, 20, 12, 0, 0) + t2 = datetime(2024, 5, 20, 12, 0, 1) + layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2 + refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)] + session.commit() + + all_ids = {r.id for r in refs} + seen_set: set[str] = set() + seen_list: list[str] = [] + after_value = None + after_id = None + for _ in range(6): + page, _, _ = list_references_page( + session, + limit=2, + sort="created_at", + order="desc", + after_cursor_value=after_value, + after_cursor_id=after_id, + ) + if not page: + break + for p in page: + assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk" + seen_set.add(p.id) + seen_list.append(p.id) + last = page[-1] + after_value, after_id = last.created_at, last.id + if len(page) < 2: + break + + assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}" diff --git a/tests-unit/assets_test/queries/test_tags.py b/tests-unit/assets_test/queries/test_tags.py index 4ed99aa37..6222714d1 100644 --- a/tests-unit/assets_test/queries/test_tags.py +++ b/tests-unit/assets_test/queries/test_tags.py @@ -40,15 +40,15 @@ def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id class TestEnsureTagsExist: def test_creates_new_tags(self, session: Session): - ensure_tags_exist(session, ["alpha", "beta"], tag_type="user") + ensure_tags_exist(session, ["alpha", "beta"]) session.commit() tags = session.query(Tag).all() assert {t.name for t in tags} == {"alpha", "beta"} def 
test_is_idempotent(self, session: Session): - ensure_tags_exist(session, ["alpha"], tag_type="user") - ensure_tags_exist(session, ["alpha"], tag_type="user") + ensure_tags_exist(session, ["alpha"]) + ensure_tags_exist(session, ["alpha"]) session.commit() assert session.query(Tag).count() == 1 @@ -65,13 +65,6 @@ class TestEnsureTagsExist: session.commit() assert session.query(Tag).count() == 0 - def test_tag_type_is_set(self, session: Session): - ensure_tags_exist(session, ["system-tag"], tag_type="system") - session.commit() - - tag = session.query(Tag).filter_by(name="system-tag").one() - assert tag.tag_type == "system" - class TestGetReferenceTags: def test_returns_empty_for_no_tags(self, session: Session): @@ -193,7 +186,7 @@ class TestMissingTagFunctions: def test_add_missing_tag_for_asset_id(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) session.commit() @@ -204,7 +197,7 @@ class TestMissingTagFunctions: def test_add_missing_tag_is_idempotent(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) add_missing_tag_for_asset_id(session, asset_id=asset.id) @@ -216,7 +209,7 @@ class TestMissingTagFunctions: def test_remove_missing_tag_for_asset_id(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) - ensure_tags_exist(session, ["missing"], tag_type="system") + ensure_tags_exist(session, ["missing"]) add_missing_tag_for_asset_id(session, asset_id=asset.id) remove_missing_tag_for_asset_id(session, asset_id=asset.id) @@ -237,7 +230,7 @@ class TestListTagsWithUsage: rows, total = list_tags_with_usage(session) 
- tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict["used"] == 1 assert tag_dict["unused"] == 0 assert total == 2 @@ -252,7 +245,7 @@ class TestListTagsWithUsage: rows, total = list_tags_with_usage(session, include_zero=False) - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert "used" in tag_names assert "unused" not in tag_names @@ -262,7 +255,7 @@ class TestListTagsWithUsage: rows, total = list_tags_with_usage(session, prefix="alph") - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} def test_order_by_name(self, session: Session): @@ -271,7 +264,7 @@ class TestListTagsWithUsage: rows, _ = list_tags_with_usage(session, order="name_asc") - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names == ["alpha", "middle", "zebra"] def test_owner_visibility(self, session: Session): @@ -287,13 +280,13 @@ class TestListTagsWithUsage: # Empty owner sees only shared rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False) - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict.get("shared-tag", 0) == 1 assert tag_dict.get("owner-tag", 0) == 0 # User1 sees both rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False) - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict.get("shared-tag", 0) == 1 assert tag_dict.get("owner-tag", 0) == 1 diff --git a/tests-unit/assets_test/services/test_cursor.py b/tests-unit/assets_test/services/test_cursor.py new file mode 100644 index 000000000..47970e168 --- /dev/null +++ b/tests-unit/assets_test/services/test_cursor.py @@ -0,0 +1,278 @@ +"""Tests for app.assets.services.cursor. 
+ +Cursors are opaque tokens internal to this server — these tests cover +round-tripping, validation, and length caps, not any particular wire +byte layout. +""" +from __future__ import annotations + +import base64 +from datetime import datetime, timedelta, timezone + +import pytest + +from app.assets.services.cursor import ( + MAX_CURSOR_ID_LENGTH, + MAX_CURSOR_VALUE_LENGTH, + MAX_ENCODED_CURSOR_LENGTH, + CursorPayload, + InvalidCursorError, + decode_cursor, + decode_cursor_int, + decode_cursor_time, + encode_cursor, + encode_cursor_from_time, +) + + +ALLOWED = ("created_at", "updated_at", "name", "size") + + +class TestRoundTrip: + @pytest.mark.parametrize( + "sort_field, value, id", + [ + ("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"), + ("size", "1024", "asset-123"), + ("name", "my-asset.png", "asset-abc"), + ("name", "résumé.txt", "asset-uni"), + ("name", "foo<&>bar.png", "asset-html"), + ("name", 'quo"te\\back\nnewline.png', "asset-esc"), + ], + ) + def test_encode_decode(self, sort_field, value, id): + encoded = encode_cursor(sort_field, value, id) + assert encoded != "" + payload = decode_cursor(encoded, ALLOWED) + assert payload.sort_field == sort_field + assert payload.value == value + assert payload.id == id + + +class TestTimeCursor: + def test_microsecond_precision_preserved(self): + # Pick a time with non-zero microseconds — encoding at ms would lose the µs. + ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc) + encoded = encode_cursor_from_time("created_at", ts, "id-1") + payload = decode_cursor(encoded, ALLOWED) + # Value must be a microsecond integer string, not a millisecond one. 
+ assert payload.value == "1716209600123456" + decoded = decode_cursor_time(payload) + assert decoded == ts + + def test_decode_returns_utc(self): + payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc") + decoded = decode_cursor_time(payload) + assert decoded.tzinfo == timezone.utc + + def test_naive_datetime_rejected_on_encode(self): + naive = datetime(2024, 5, 20, 12, 0, 0) + with pytest.raises(ValueError): + encode_cursor_from_time("created_at", naive, "id-1") + + def test_non_integer_value_rejected_on_decode(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc")) + + def test_none_payload_rejected(self): + with pytest.raises(InvalidCursorError): + decode_cursor_time(None) + + def test_non_utc_aware_normalized(self): + # Same instant, different timezone — must encode to the same micros. + utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc) + offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5))) + assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time( + "created_at", offset_ts, "x" + ) + + +class TestIntCursor: + def test_decode_int(self): + assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024 + + def test_decode_int_rejects_non_int(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc")) + + def test_decode_int_rejects_none(self): + with pytest.raises(InvalidCursorError): + decode_cursor_int(None) + + +class TestInvalidInputs: + def test_oversized_cursor(self): + oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1) + with pytest.raises(InvalidCursorError, match="maximum length"): + decode_cursor(oversized, ALLOWED) + + def test_not_base64(self): + with pytest.raises(InvalidCursorError): + decode_cursor("not base64!!!", ALLOWED) + + def test_not_json(self): + encoded = base64.urlsafe_b64encode(b"definitely not 
json").rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ALLOWED) + + def test_empty_id(self): + # Encoder rejects empty id symmetrically with the decoder, so build the + # payload manually to exercise the decoder's missing-id branch. + raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="missing id"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_id(self): + # Encoder enforces the cap symmetrically; hand-build to exercise decode. + big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1) + raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii") + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="id exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_oversized_value(self): + # Encoder enforces the cap symmetrically; hand-build to exercise decode. 
+ big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1) + raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii") + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="value exceeds maximum length"): + decode_cursor(encoded, ALLOWED) + + def test_unsupported_sort_field(self): + encoded = encode_cursor("execution_time", "1", "id-1") + with pytest.raises(InvalidCursorError, match="unsupported sort field"): + decode_cursor(encoded, ALLOWED) + + def test_no_allowed_fields_rejects_everything(self): + encoded = encode_cursor("created_at", "1", "id-1") + with pytest.raises(InvalidCursorError): + decode_cursor(encoded, ()) + + def test_non_dict_payload_rejected(self): + encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="expected object"): + decode_cursor(encoded, ALLOWED) + + +class TestEncodeAtCapsFits: + def test_max_field_lengths_fit_wire_cap(self): + # Worst-case payload: value and id at their per-field caps, with a long + # sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH + # so the wire cap cannot reject a cursor the encoder mints at the per-field caps. + value = "v" * MAX_CURSOR_VALUE_LENGTH + id = "i" * MAX_CURSOR_ID_LENGTH + sort_field = "very_long_sort_field_name" + + encoded = encode_cursor(sort_field, value, id) + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, (sort_field,)) + assert payload.value == value + assert payload.id == id + + +class TestDatetimeOverflow: + """Crafted cursors with extreme micros must map to InvalidCursorError, + not OverflowError/OSError leaking as 500. 
+ """ + + @pytest.mark.parametrize( + "micros_str", + [ + "999999999999999999999", # ~10^21 µs — about 4 orders of magnitude past datetime.max (year 9999, ~2.5e17 µs) + "-999999999999999999999", # symmetric negative — pre-epoch overflow + ], + ) + def test_out_of_range_micros_rejected(self, micros_str): + encoded = encode_cursor("created_at", micros_str, "asset-x") + payload = decode_cursor(encoded, ALLOWED) + with pytest.raises(InvalidCursorError): + decode_cursor_time(payload) + + +class TestEncoderDecoderSymmetry: + """The encoder must never mint a cursor the decoder would reject, or the + same server would 400 on a cursor it just handed out. Per-field caps keep + the encoded length below the wire cap, so a freshly minted cursor always + round-trips. + """ + + def test_long_name_within_cap_round_trips(self): + """Assets allow names up to 512 chars (`String(512)`); the cursor + encoder must round-trip a value at that cap so a freshly minted + cursor never fails decode on the next request.""" + long_name = "n" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", long_name, "asset-x") + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == long_name + + def test_encoder_rejects_empty_id(self): + with pytest.raises(InvalidCursorError, match="id must be non-empty"): + encode_cursor("created_at", "1", "") + + def test_encoder_rejects_oversized_id(self): + with pytest.raises(InvalidCursorError, match="id exceeds maximum length"): + encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1)) + + def test_encoder_rejects_oversized_value(self): + with pytest.raises(InvalidCursorError, match="value exceeds maximum length"): + encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1") + + def test_multibyte_value_at_cap_round_trips(self): + """A value at the char-count cap made of multibyte characters + (e.g.
'é' = 2 UTF-8 bytes) stays under the wire cap, so it mints and + round-trips — the per-field caps, not a mint-time length check, are + what bound cursor size.""" + value = "é" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", value, "asset-multibyte") + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == value + + def test_escape_heavy_value_at_cap_round_trips(self): + """JSON escape expansion is the worst case: each control character + serializes to the six-byte `\\uXXXX` form. A value of 512 of them is + the largest a cursor can get, and it still fits the wire cap, mints, + and round-trips.""" + value = "\x01" * MAX_CURSOR_VALUE_LENGTH + encoded = encode_cursor("name", value, "asset-escape") + assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH + payload = decode_cursor(encoded, ALLOWED) + assert payload.value == value + + +class TestOrderBinding: + def test_order_baked_into_payload(self): + encoded = encode_cursor("created_at", "1", "id-1", order="asc") + payload = decode_cursor(encoded, ALLOWED) + assert payload.order == "asc" + + def test_mismatched_order_rejected(self): + encoded = encode_cursor("created_at", "1", "id-1", order="desc") + with pytest.raises(InvalidCursorError, match="does not match request order"): + decode_cursor(encoded, ALLOWED, expected_order="asc") + + def test_matching_order_accepted(self): + encoded = encode_cursor("created_at", "1", "id-1", order="desc") + payload = decode_cursor(encoded, ALLOWED, expected_order="desc") + assert payload.order == "desc" + + def test_invalid_order_token_rejected_on_encode(self): + with pytest.raises(ValueError): + encode_cursor("created_at", "1", "id-1", order="sideways") + + def test_invalid_order_token_rejected_on_decode(self): + # Hand-craft a payload with an illegal `o` value. 
+ raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="unsupported order"): + decode_cursor(encoded, ALLOWED) + + def test_cursor_without_order_rejected(self): + """`o` is mandatory. A cursor minted without it is rejected as + malformed rather than silently walking the keyset in whatever + direction the request happens to ask for.""" + raw = b'{"s":"name","v":"x","id":"id-1"}' + encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + with pytest.raises(InvalidCursorError, match="missing or non-string o"): + decode_cursor(encoded, ALLOWED, expected_order="desc") diff --git a/tests-unit/assets_test/services/test_image_dimensions.py b/tests-unit/assets_test/services/test_image_dimensions.py new file mode 100644 index 000000000..ac275eae2 --- /dev/null +++ b/tests-unit/assets_test/services/test_image_dimensions.py @@ -0,0 +1,86 @@ +"""Tests for the image_dimensions service.""" +from __future__ import annotations + +from pathlib import Path + +import pytest +from PIL import Image + +from app.assets.services.image_dimensions import extract_image_dimensions + + +def _make_png(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(123, 45, 67)) + img.save(path, format="PNG") + return path + + +def _make_jpeg(path: Path, size: tuple[int, int]) -> Path: + img = Image.new("RGB", size, color=(10, 20, 30)) + img.save(path, format="JPEG", quality=80) + return path + + +class TestExtractImageDimensions: + def test_extracts_png_dimensions(self, tmp_path: Path): + f = _make_png(tmp_path / "rect.png", (320, 240)) + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result == {"kind": "image", "width": 320, "height": 240} + + def test_extracts_jpeg_dimensions(self, tmp_path: Path): + f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080)) + + result = extract_image_dimensions(str(f), 
mime_type="image/jpeg") + + assert result == {"kind": "image", "width": 1920, "height": 1080} + + def test_works_when_mime_type_is_none(self, tmp_path: Path): + f = _make_png(tmp_path / "no_mime.png", (50, 100)) + + result = extract_image_dimensions(str(f), mime_type=None) + + assert result == {"kind": "image", "width": 50, "height": 100} + + def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path): + # Path doesn't need to exist — non-image MIME short-circuits. + result = extract_image_dimensions( + str(tmp_path / "model.safetensors"), + mime_type="application/octet-stream", + ) + + assert result is None + + @pytest.mark.parametrize( + "mime", + ["application/json", "text/plain", "video/mp4", "audio/mpeg"], + ) + def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str): + f = tmp_path / "file.bin" + f.write_bytes(b"\x00\x01\x02") + + assert extract_image_dimensions(str(f), mime_type=mime) is None + + def test_returns_none_for_missing_file(self, tmp_path: Path): + result = extract_image_dimensions( + str(tmp_path / "does_not_exist.png"), mime_type="image/png" + ) + + assert result is None + + def test_returns_none_for_corrupt_image(self, tmp_path: Path): + f = tmp_path / "corrupt.png" + f.write_bytes(b"not actually a png file") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None + + def test_returns_none_for_empty_file(self, tmp_path: Path): + f = tmp_path / "empty.png" + f.write_bytes(b"") + + result = extract_image_dimensions(str(f), mime_type="image/png") + + assert result is None diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index b153f9795..12a3bdfe6 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -4,10 +4,12 @@ from pathlib import Path from unittest.mock import patch import pytest +from PIL import Image from sqlalchemy.orm import Session as SASession, 
Session from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag from app.assets.database.queries import get_reference_tags +from app.assets.helpers import get_utc_now from app.assets.services.ingest import ( _ingest_file_from_path, _register_existing_asset, @@ -15,6 +17,11 @@ from app.assets.services.ingest import ( ) +def _make_png(path: Path, size: tuple[int, int]) -> Path: + Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG") + return path + + class TestIngestFileFromPath: def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session): file_path = temp_dir / "test_file.bin" @@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK: ref_tags = sess.query(AssetReferenceTag).all() ref_tag_names = {rt.tag_name for rt in ref_tags} assert "output" in ref_tag_names - assert "my-job" in ref_tag_names + + +class TestIngestImageDimensions: + """system_metadata should carry {kind, width, height} for image assets.""" + + def test_image_asset_emits_dimensions( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "shot.png", (640, 480)) + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert ref.system_metadata == { + "kind": "image", + "width": 640, + "height": 480, + } + + def test_non_image_asset_leaves_system_metadata_empty( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = temp_dir / "model.safetensors" + f.write_bytes(b"not an image") + + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:safetensors1", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="application/octet-stream", + ) + + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + assert 
ref.system_metadata in (None, {}) + + def test_preserves_existing_system_metadata_keys( + self, mock_create_session, temp_dir: Path, session: Session + ): + f = _make_png(temp_dir / "annotated.png", (100, 200)) + + # First pass populates a sentinel system_metadata key (simulating prior + # enricher write). + result = _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000000, + mime_type="image/png", + ) + ref = session.query(AssetReference).filter_by(id=result.reference_id).first() + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"} + session.commit() + + # Second pass with the same path triggers the merge code path again. + _ingest_file_from_path( + abs_path=str(f), + asset_hash="blake3:img-merge", + size_bytes=f.stat().st_size, + mtime_ns=1234567890000000001, + mime_type="image/png", + ) + + session.refresh(ref) + assert ref.system_metadata["kind"] == "image" + assert ref.system_metadata["width"] == 100 + assert ref.system_metadata["height"] == 200 + assert ref.system_metadata["source_url"] == "https://example/x.png" + + +class TestRegisterExistingAssetBackfill: + """The from-hash path back-fills dimensions from a sibling reference.""" + + def _add_reference( + self, + session: Session, + asset: Asset, + name: str, + system_metadata: dict | None = None, + ) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + asset_id=asset.id, + name=name, + owner_id="", + created_at=now, + updated_at=now, + last_access_time=now, + system_metadata=system_metadata or {}, + ) + session.add(ref) + session.flush() + return ref + + def test_backfills_dimensions_from_sibling_image_reference( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": 
"image", "width": 800, "height": 600}, + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:shared", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 800 + assert ref.system_metadata.get("height") == 600 + + def test_no_backfill_when_sibling_has_no_image_metadata( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"base_model": "flux"}, # no kind=image + ) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:nodims", + name="from_hash.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + meta = ref.system_metadata or {} + assert "kind" not in meta + assert "width" not in meta + assert "height" not in meta + + def test_no_backfill_when_no_sibling_exists( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png") + session.add(asset) + session.commit() + + result = _register_existing_asset( + asset_hash="blake3:lonely", + name="solo.png", + owner_id="user-x", + ) + + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + assert ref.system_metadata in (None, {}) + + def test_backfill_preserves_caller_supplied_keys( + self, mock_create_session, session: Session + ): + asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png") + session.add(asset) + session.flush() + self._add_reference( + session, + asset, + name="original.png", + system_metadata={"kind": "image", "width": 1024, "height": 768}, + ) + session.commit() + + # Simulate a from-hash path where the new reference already 
carries + # some system_metadata (e.g. a download-provenance source_url written + # by an earlier step). The back-fill must merge dim keys without + # clobbering existing keys. + result = _register_existing_asset( + asset_hash="blake3:preserve", + name="from_hash.png", + owner_id="user-x", + ) + ref = session.query(AssetReference).filter_by(id=result.ref.id).first() + # Seed a sentinel key on the back-filled reference and check it coexists + # with the dim keys (no second register call is made, so back-fill is not re-run). + ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"} + session.commit() + + assert ref.system_metadata.get("source_url") == "https://example/p" + assert ref.system_metadata.get("kind") == "image" + assert ref.system_metadata.get("width") == 1024 + assert ref.system_metadata.get("height") == 768 diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py index ab69e5dc1..fa121db3e 100644 --- a/tests-unit/assets_test/services/test_tagging.py +++ b/tests-unit/assets_test/services/test_tagging.py @@ -141,7 +141,7 @@ class TestListTags: rows, total = list_tags() - tag_dict = {name: count for name, _, count in rows} + tag_dict = {name: count for name, count in rows} assert tag_dict["used"] == 1 assert tag_dict["unused"] == 0 assert total == 2 @@ -155,7 +155,7 @@ class TestListTags: rows, total = list_tags(include_zero=False) - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert "used" in tag_names assert "unused" not in tag_names @@ -165,7 +165,7 @@ class TestListTags: rows, _ = list_tags(prefix="alph") - tag_names = {name for name, _, _ in rows} + tag_names = {name for name, _ in rows} assert tag_names == {"alpha", "alphabet"} def test_order_by_name(self, mock_create_session, session: Session): @@ -174,7 +174,7 @@ class TestListTags: rows, _ = list_tags(order="name_asc") - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names
== ["alpha", "middle", "zebra"] def test_pagination(self, mock_create_session, session: Session): @@ -185,7 +185,7 @@ class TestListTags: assert total == 5 assert len(rows) == 2 - names = [name for name, _, _ in rows] + names = [name for name, _ in rows] assert names == ["b", "c"] def test_clamps_limit(self, mock_create_session, session: Session): diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py index fd2e9a098..36abb60ee 100644 --- a/tests-unit/assets_test/test_crud.py +++ b/tests-unit/assets_test/test_crud.py @@ -45,8 +45,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse assert "user_metadata" in detail assert "filename" in detail["user_metadata"] - # DELETE (hard delete to also remove underlying asset and file) - rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + # Soft delete — the reference is hidden, content is preserved + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) assert rd.status_code == 204 # GET again -> 404 @@ -60,7 +60,7 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede aid = seeded_asset["id"] asset_hash = seeded_asset["asset_hash"] - # Soft-delete (default, no delete_content param) + # Soft delete — the reference is hidden, content is preserved rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) assert rd.status_code == 204 @@ -81,11 +81,10 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede ids = [a["id"] for a in rl.json().get("assets", [])] assert aid not in ids - # Clean up: hard-delete the soft-deleted reference and orphaned asset - http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + # The reference is already soft-deleted; content is preserved. 
-def test_delete_upon_reference_count( +def test_soft_delete_preserves_asset_identity_across_references( http: requests.Session, api_base: str, seeded_asset: dict ): # Create a second reference to the same asset via from-hash @@ -119,16 +118,20 @@ def test_delete_upon_reference_count( rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) assert rh2.status_code == 200 # asset identity preserved (soft delete) - # Re-associate via from-hash, then hard-delete -> orphan content removed + # Re-associate via from-hash: it must reuse the same preserved content + # (created_new False AND the same hash), proving the soft deletes did not + # destroy the underlying asset. Then soft-delete again -> still preserved. r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) assert r3.status_code == 201, r3.json() + assert r3.json()["created_new"] is False + assert r3.json()["asset_hash"] == src_hash # reused the surviving content aid3 = r3.json()["id"] - rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120) + rd3 = http.delete(f"{api_base}/api/assets/{aid3}", timeout=120) assert rd3.status_code == 204 rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) - assert rh3.status_code == 404 # orphan content removed + assert rh3.status_code == 200 # content preserved (soft delete) def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): @@ -249,7 +252,7 @@ def test_concurrent_delete_same_asset_info_single_204( # Hit the same endpoint N times in parallel. 
n_tests = 4 - url = f"{api_base}/api/assets/{aid}?delete_content=false" + url = f"{api_base}/api/assets/{aid}" def _do_delete(delete_url): with requests.Session() as s: diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py index 672ba9728..42c64a5fd 100644 --- a/tests-unit/assets_test/test_downloads.py +++ b/tests-unit/assets_test/test_downloads.py @@ -117,7 +117,7 @@ def test_download_missing_file_returns_404( assert body["error"]["code"] == "FILE_NOT_FOUND" finally: # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. - dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120) + dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) dr.content diff --git a/tests-unit/assets_test/test_list_cursor.py b/tests-unit/assets_test/test_list_cursor.py new file mode 100644 index 000000000..a37019fd6 --- /dev/null +++ b/tests-unit/assets_test/test_list_cursor.py @@ -0,0 +1,349 @@ +"""Integration tests for cursor-based pagination on GET /api/assets. + +These tests exercise the handler/service/query path end-to-end; +cursor-encoding-level tests live in +tests-unit/assets_test/services/test_cursor.py. 
+""" +import pytest +import requests + + +def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]: + names = [f"cursor_{i:02d}.safetensors" for i in range(count)] + for n in names: + asset_factory( + n, + ["models", "checkpoints", "unit-tests", tag], + {}, + make_asset_bytes(n, size=2048), + ) + return sorted(names) + + +def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk") + + params = { + "include_tags": "unit-tests,cursor-walk", + "sort": "name", + "order": "asc", + "limit": "2", + } + + seen: list[str] = [] + after: str | None = None + pages = 0 + while True: + page_params = dict(params) + if after is not None: + page_params["after"] = after + r = http.get(api_base + "/api/assets", params=page_params, timeout=120) + assert r.status_code == 200, r.text + body = r.json() + seen.extend(a["name"] for a in body["assets"]) + pages += 1 + after = body.get("next_cursor") + if after is None: + break + assert body["has_more"] is True + assert pages < 10, "guard against runaway cursor loop" + + assert seen == names, f"expected {names}, got {seen}" + # Last page should have has_more False + assert body["has_more"] is False + assert "next_cursor" not in body + + +def test_cursor_invalid_returns_400(http: requests.Session, api_base: str): + r = http.get( + api_base + "/api/assets", + params={"after": "not-a-real-cursor", "sort": "created_at"}, + timeout=120, + ) + assert r.status_code == 400, r.text + body = r.json() + assert body["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch") + + # Take a real cursor minted for sort=name. 
+ r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-mismatch", + "sort": "name", + "order": "asc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200 + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Replay against sort=created_at — should fail with INVALID_CURSOR. + r2 = http.get( + api_base + "/api/assets", + params={"after": cursor, "sort": "created_at"}, + timeout=120, + ) + assert r2.status_code == 400, r2.text + assert r2.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset") + + # Take a cursor that points past the first item. + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-vs-offset", + "sort": "name", + "order": "asc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Pass both 'after' and a large offset. Cursor must win; offset is ignored. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-vs-offset", + "sort": "name", + "order": "asc", + "limit": "1", + "after": cursor, + "offset": "999", + }, + timeout=120, + ) + assert r2.status_code == 200 + body = r2.json() + # Should land on the second name in sorted order — not skip ahead by 999. 
+ assert [a["name"] for a in body["assets"]] == [names[1]] + + +def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + _seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust") + + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-exhaust", + "sort": "name", + "order": "asc", + "limit": "50", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["has_more"] is False + assert "next_cursor" not in body + + +def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """First-page request (no `after`) must still return `next_cursor` when + more rows exist, or pagination is unreachable from a cold start. + """ + _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page") + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"}, + timeout=120, + ) + assert r.status_code == 200, r.text + body = r.json() + assert body["has_more"] is True + assert body.get("next_cursor"), "first page must mint a cursor when more rows exist" + + +def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """When `total` is an exact multiple of `limit`, the final page must + NOT carry a next_cursor — there is nothing past it. 
+ """ + _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple") + # Page 1 + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"}, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + # Page 2 — should exhaust the set with no cursor for a phantom page 3 + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor}, + timeout=120, + ) + assert r2.status_code == 200, r2.text + body = r2.json() + assert len(body["assets"]) == 2 + assert body["has_more"] is False + assert "next_cursor" not in body + + +@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"]) +def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """Cursor pagination must work for every sort field the contract claims. + + Without this, the `created_at` / `updated_at` (time-encoded micros) and + `size` (int-encoded) cursor paths go entirely unexercised end-to-end. + """ + # Sizes increase strictly by index, so `size desc` has a deterministic + # expected order. Time-based sorts (created_at / updated_at) can tie when + # rows are inserted faster than the DB's timestamp resolution; for those + # we check coverage and no-duplicates and let the keyset tiebreaker do + # the rest, instead of sleeping between inserts and asserting an order + # that depends on clock granularity. 
+ names = [] + for i in range(4): + n = f"cursor_{sort_field}_{i:02d}.safetensors" + asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i)) + names.append(n) + + params = { + "include_tags": f"unit-tests,cursor-{sort_field}", + "sort": sort_field, + "order": "desc", + "limit": "2", + } + seen: list[str] = [] + after: str | None = None + pages = 0 + while True: + page_params = dict(params) + if after is not None: + page_params["after"] = after + r = http.get(api_base + "/api/assets", params=page_params, timeout=120) + assert r.status_code == 200, r.text + body = r.json() + seen.extend(a["name"] for a in body["assets"]) + after = body.get("next_cursor") + pages += 1 + if after is None: + break + assert pages < 10, "guard against runaway cursor loop" + + # No duplicates: a faulty keyset boundary that returns the same row across + # two pages must fail this check. + assert len(seen) == len(set(seen)), ( + f"cursor walk repeated rows for sort={sort_field}: {seen}" + ) + # Full coverage: every seeded asset reached exactly once. + assert set(seen) == set(names), ( + f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}" + ) + # Strict order check for the only field with a clock-independent ordering. 
+ if sort_field == "size": + assert seen == list(reversed(names)), ( + f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}" + ) + + +def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + """A cursor minted under desc order replayed against asc must 400, not + silently walk the wrong direction.""" + _seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip") + + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-order-flip", + "sort": "name", + "order": "desc", + "limit": "1", + }, + timeout=120, + ) + assert r.status_code == 200, r.text + cursor = r.json()["next_cursor"] + assert cursor is not None + + # Replay with order flipped to asc — server must reject the cursor. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-order-flip", + "sort": "name", + "order": "asc", + "limit": "1", + "after": cursor, + }, + timeout=120, + ) + assert r2.status_code == 400, r2.text + assert r2.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str): + """A cursor carrying an out-of-range microsecond timestamp must map to + 400 INVALID_CURSOR, not 500.""" + import base64 + import json + # 10^18 microseconds ≈ year 33658, well past datetime.MAXYEAR. + # `o` and `order=` must be set; otherwise decode fails earlier on the + # missing-order branch and the µs-overflow path is never exercised.
+ payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"} + raw = json.dumps(payload, separators=(",", ":")).encode("utf-8") + cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii") + r = http.get( + api_base + "/api/assets", + params={"after": cursor, "sort": "created_at", "order": "desc"}, + timeout=120, + ) + assert r.status_code == 400, r.text + assert r.json()["error"]["code"] == "INVALID_CURSOR" + + +def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete") + + # Page 1. + r = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-delete", + "sort": "name", + "order": "asc", + "limit": "2", + }, + timeout=120, + ) + assert r.status_code == 200 + body = r.json() + page1_names = [a["name"] for a in body["assets"]] + cursor = body["next_cursor"] + assert cursor is not None + assert page1_names == names[:2] + + # Delete an item from page 1 (already returned) — cursor should still + # locate the next page from where it was minted, not re-index. + target_id = body["assets"][0]["id"] + d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120) + assert d.status_code in (200, 204), d.text + + # Page 2 via cursor. + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,cursor-delete", + "sort": "name", + "order": "asc", + "limit": "2", + "after": cursor, + }, + timeout=120, + ) + assert r2.status_code == 200, r2.text + body2 = r2.json() + assert [a["name"] for a in body2["assets"]] == names[2:] diff --git a/tests-unit/assets_test/test_prompt_id_enforcement.py b/tests-unit/assets_test/test_prompt_id_enforcement.py new file mode 100644 index 000000000..86a755c9f --- /dev/null +++ b/tests-unit/assets_test/test_prompt_id_enforcement.py @@ -0,0 +1,69 @@ +"""POST /prompt enforces canonical-UUID job ids at creation time. 
+ +Lives in assets_test because it uses this suite's booted-server fixture. The +invariant itself is pipeline-wide: a job id is stored and compared verbatim +downstream — history keys, websocket correlation, and /interrupt matching — +so a job minted with a non-canonical id would miss every exact-match lookup. + +The prompt bodies here are intentionally invalid workflows — prompt_id +validation happens before workflow validation, so a rejected id returns +``invalid_prompt_id`` while an accepted id falls through to the ordinary +workflow-validation error (proving it cleared the id check). +""" +import requests + + +def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response: + return http.post(api_base + "/prompt", json=body, timeout=30) + + +def _error_type(r: requests.Response) -> str: + return r.json()["error"]["type"] + + +def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str): + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str): + # Previously str()-coerced (123 became the job id "123"); must now be a 400, + # not a 500 from uuid.UUID choking on a non-string. + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123}) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str): + # Parseable as a UUID, but not the canonical lowercase form: rejected + # loudly rather than silently rewritten (downstream lookups match the + # stored id exactly). 
+ r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) == "invalid_prompt_id" + + +def test_canonical_uuid_accepted(http: requests.Session, api_base: str): + # The id clears validation; the empty workflow then fails ordinary prompt + # validation, proving the request got past the id check. + r = _post_prompt( + http, + api_base, + {"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"}, + ) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" + + +def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str): + # Explicit null means "server generates" and must not be rejected as an + # invalid id. (The minted id itself is not observable here because the + # workflow is invalid; unit tests cover validate_job_id directly.) + r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None}) + assert r.status_code == 400, r.text + assert _error_type(r) != "invalid_prompt_id" diff --git a/tests-unit/assets_test/test_sync_references.py b/tests-unit/assets_test/test_sync_references.py index 94cc255bc..2e85076e0 100644 --- a/tests-unit/assets_test/test_sync_references.py +++ b/tests-unit/assets_test/test_sync_references.py @@ -95,7 +95,7 @@ def _make_asset( def _ensure_missing_tag(session: Session): """Ensure the 'missing' tag exists.""" if not session.get(Tag, "missing"): - session.add(Tag(name="missing", tag_type="system")) + session.add(Tag(name="missing")) session.flush() diff --git a/tests-unit/assets_test/test_tags_api.py b/tests-unit/assets_test/test_tags_api.py index 595bf29c6..9729b7d03 100644 --- a/tests-unit/assets_test/test_tags_api.py +++ b/tests-unit/assets_test/test_tags_api.py @@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, used_names = [t["name"] for t in body2["tags"]] assert custom_tag in used_names - # Hard-delete the asset 
so the tag usage drops to zero - rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120) + # Delete the asset reference so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) assert rd.status_code == 204 # Now the custom tag must not be returned when include_zero=false diff --git a/tests-unit/comfy_api_test/video_bit_depth_test.py b/tests-unit/comfy_api_test/video_bit_depth_test.py new file mode 100644 index 000000000..6c7bc9163 --- /dev/null +++ b/tests-unit/comfy_api_test/video_bit_depth_test.py @@ -0,0 +1,93 @@ +import pytest +import torch +import av +import numpy as np +from fractions import Fraction +from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents +from comfy_api.latest._util.video_types import VideoComponents + + +@pytest.fixture(scope="module") +def gradient_components(): + """Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth""" + width, height, frames = 64, 64, 3 + ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3) + return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30)) + + +@pytest.fixture(scope="module") +def src8(gradient_components, tmp_path_factory): + """8-bit h264 mp4 (Create Video default)""" + path = str(tmp_path_factory.mktemp("video") / "src8.mp4") + VideoFromComponents(gradient_components).save_to(path) + return path + + +@pytest.fixture(scope="module") +def src10(gradient_components, tmp_path_factory): + """10-bit h264 mp4 (Create Video with bit_depth=10)""" + path = str(tmp_path_factory.mktemp("video") / "src10.mp4") + VideoFromComponents(gradient_components, bit_depth=10).save_to(path) + return path + + +def probe(path): + """(codec, pix_fmt, bit_depth) of the first video stream""" + with av.open(path) as container: + stream = container.streams.video[0] + return (stream.codec.name, stream.format.name, max(c.bits for c in 
stream.format.components)) + + +def decoded_levels(path): + """Unique tonal levels in the first decoded frame (banding measure)""" + with av.open(path) as container: + frame = next(container.decode(container.streams.video[0])) + return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0])) + + +def video_packet_bytes(path): + """Raw video packet payloads; identical to the source's only for a true remux""" + with av.open(path) as container: + return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size] + + +def test_create_video_bit_depth(src8, src10): + """Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding""" + assert probe(src8) == ("h264", "yuv420p", 8) + assert probe(src10) == ("h264", "yuv420p10le", 10) + assert decoded_levels(src10) > 2 * decoded_levels(src8) + + +def test_save_auto_keeps_source_depth(src8, src10, tmp_path): + """Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte""" + for name, src in [("p8", src8), ("p10", src10)]: + path = str(tmp_path / f"{name}.mp4") + VideoFromFile(src).save_to(path) + assert probe(path) == probe(src) + assert video_packet_bytes(path) == video_packet_bytes(src) + + +def test_save_explicit_depth_reencodes(src8, src10, tmp_path): + """An explicit bit_depth different from the source forces a re-encode to that depth""" + down = str(tmp_path / "down8.mp4") + VideoFromFile(src10).save_to(down, bit_depth=8) + assert probe(down) == ("h264", "yuv420p", 8) + + up = str(tmp_path / "up10.mp4") + VideoFromFile(src8).save_to(up, bit_depth=10) + assert probe(up) == ("h264", "yuv420p10le", 10) + + +def test_trim_keeps_source_depth(src10, tmp_path): + """Video Slice re-encodes (trim) but preserves the source's 10-bit depth""" + path = str(tmp_path / "trim.mp4") + VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path) + assert probe(path) == ("h264", "yuv420p10le", 10) + + +def 
test_get_bit_depth(gradient_components, src8, src10): + """get_bit_depth reports a video's depth (backs the Get Video Components output)""" + assert VideoFromFile(src8).get_bit_depth() == 8 + assert VideoFromFile(src10).get_bit_depth() == 10 + assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10 + assert VideoFromComponents(gradient_components).get_bit_depth() == 8 diff --git a/tests-unit/execution_test/test_enrich_output.py b/tests-unit/execution_test/test_enrich_output.py new file mode 100644 index 000000000..61490c49e --- /dev/null +++ b/tests-unit/execution_test/test_enrich_output.py @@ -0,0 +1,205 @@ +"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py.""" +import os +import types +import unittest +from unittest.mock import MagicMock, patch + + +def _make_args(enable_assets: bool): + a = types.SimpleNamespace() + a.enable_assets = enable_assets + return a + + +def _make_register_result(ref_id="ref-id-2"): + result = MagicMock() + result.ref.id = ref_id + return result + + +# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on +# Windows and /tmp on POSIX, so containment via commonpath behaves naturally. 
+_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base") + + +def _mocked_modules(*, enable_assets=True, register_file_in_place=None, directory=_DEFAULT_BASE): + return { + "comfy.cli_args": MagicMock(args=_make_args(enable_assets)), + "folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=directory)), + "app.assets.services.ingest": MagicMock( + register_file_in_place=register_file_in_place or MagicMock(return_value=_make_register_result()), + DependencyMissingError=type("DependencyMissingError", (Exception,), {}), + ), + } + + +def _call(output_ui, *, enable_assets=True, file_exists=True, register_result=None, directory=_DEFAULT_BASE): + register_mock = MagicMock(return_value=register_result or _make_register_result()) + mocked = _mocked_modules( + enable_assets=enable_assets, + register_file_in_place=register_mock, + directory=directory, + ) + + # Only os.path.isfile is patched — abspath/join must run natively so the + # containment check sees real platform paths. 
+ with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=file_exists): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + return mod.enrich_output_with_assets(output_ui) + + +class TestEnrichOutputWithAssets(unittest.TestCase): + + def test_disabled_returns_unchanged(self): + output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]} + result = _call(output, enable_assets=False) + self.assertNotIn("id", result["images"][0]) + + def test_non_list_value_passed_through(self): + output = {"text": "hello"} + result = _call(output) + self.assertEqual(result["text"], "hello") + + def test_entry_without_filename_unchanged(self): + output = {"latent": [{"subfolder": "", "type": "output"}]} + result = _call(output) + self.assertNotIn("id", result["latent"][0]) + + def test_entry_without_type_unchanged(self): + output = {"data": [{"filename": "a.png", "subfolder": ""}]} + result = _call(output) + self.assertNotIn("id", result["data"][0]) + + def test_file_not_on_disk_unchanged(self): + output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]} + result = _call(output, file_exists=False) + self.assertNotIn("id", result["images"][0]) + + def test_unknown_type_returns_none_directory_unchanged(self): + output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]} + result = _call(output, directory=None) + self.assertNotIn("id", result["images"][0]) + + def test_register_injects_only_id(self): + reg = _make_register_result(ref_id="inline-ref") + output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]} + result = _call(output, register_result=reg) + img = result["images"][0] + self.assertEqual(img["id"], "inline-ref") + # Only id is injected — no asset_hash, name, preview_url, size + self.assertNotIn("asset_hash", img) + self.assertNotIn("name", img) + self.assertNotIn("preview_url", img) + self.assertNotIn("size", img) + + 
def test_register_called_per_entry(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + output = { + "images": [ + {"filename": "a.png", "subfolder": "", "type": "output"}, + {"filename": "b.png", "subfolder": "", "type": "output"}, + ] + } + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + mod.enrich_output_with_assets(output) + + self.assertEqual(register_mock.call_count, 2) + + def test_original_entry_not_mutated(self): + orig = {"filename": "a.png", "subfolder": "", "type": "output"} + output = {"images": [orig]} + _call(output) + self.assertNotIn("id", orig) + + def test_enrichment_error_does_not_block_sibling_entries(self): + call_count = [0] + good_reg = _make_register_result(ref_id="good-ref") + + def register_side_effect(abs_path, name, tags): + call_count[0] += 1 + if call_count[0] == 1: + raise RuntimeError("boom") + return good_reg + + mocked = _mocked_modules(register_file_in_place=register_side_effect) + + output = { + "images": [ + {"filename": "bad.png", "subfolder": "", "type": "output"}, + {"filename": "good.png", "subfolder": "", "type": "output"}, + ] + } + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + imgs = result["images"] + self.assertNotIn("id", imgs[0]) + self.assertEqual(imgs[1]["id"], "good-ref") + + def test_multiple_output_keys_all_enriched(self): + output = { + "images": [{"filename": "a.png", "subfolder": "", "type": "output"}], + "videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}], + } + result = _call(output) + self.assertIn("id", result["images"][0]) + self.assertIn("id", result["videos"][0]) + + def 
test_none_entry_in_list_unchanged(self): + output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]} + result = _call(output) + self.assertIsNone(result["images"][0]) + self.assertIn("id", result["images"][1]) + + def test_path_traversal_subfolder_skipped(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + + output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]} + + # Do NOT patch os.path.abspath — real resolution is required for the containment check. + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + self.assertNotIn("id", result["images"][0]) + register_mock.assert_not_called() + + def test_absolute_filename_skipped(self): + register_mock = MagicMock(return_value=_make_register_result()) + mocked = _mocked_modules(register_file_in_place=register_mock) + + # Absolute filename — os.path.join discards earlier components when a later one is absolute. 
+ absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd") + output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]} + + with patch.dict("sys.modules", mocked), \ + patch("os.path.isfile", return_value=True): + import importlib + import comfy_execution.asset_enrichment as mod + importlib.reload(mod) + result = mod.enrich_output_with_assets(output) + + self.assertNotIn("id", result["images"][0]) + register_mock.assert_not_called() + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 814af5c13..f7cb612e4 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -1,5 +1,7 @@ """Unit tests for comfy_execution/jobs.py""" +import pytest + from comfy_execution.jobs import ( JobStatus, is_previewable, @@ -10,9 +12,50 @@ from comfy_execution.jobs import ( get_outputs_summary, apply_sorting, has_3d_extension, + validate_job_id, ) +class TestValidateJobId: + """validate_job_id guards job creation: POST /prompt rejects ids it raises on.""" + + def test_canonical_form_passes_through(self): + cid = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7" + assert validate_job_id(cid) == cid + + @pytest.mark.parametrize( + "variant", + [ + "A1B2C3D4-E5F6-7A89-B0C1-D2E3F4A5B6C7", # uppercase + "{a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7}", # braced + "urn:uuid:a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", # URN + "a1b2c3d4e5f67a89b0c1d2e3f4a5b6c7", # bare hex + " a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7 ", # padded + ], + ) + def test_non_canonical_spellings_rejected(self, variant): + # uuid.UUID parses all of these, but accepting them would silently + # rewrite the client's id (history keys, websocket events, and + # /interrupt matching all match the stored form exactly). 
+ with pytest.raises(ValueError): + validate_job_id(variant) + + @pytest.mark.parametrize( + "bad", + ["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"], + ) + def test_non_uuid_strings_rejected(self, bad): + with pytest.raises(ValueError): + validate_job_id(bad) + + @pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}]) + def test_non_strings_rejected(self, bad): + # uuid.UUID raises AttributeError/TypeError on non-strings; the helper + # must normalize those to ValueError so callers need one except clause. + with pytest.raises(ValueError): + validate_job_id(bad) + + class TestJobStatus: """Test JobStatus constants."""