diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat index ed00583b6..4501ef9a1 100644 --- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -1,3 +1,3 @@ ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat index 4898a424f..6487ac7ce 100755 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat index 32611e4af..01c5bb33b 100644 --- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat +++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat @@ -1,3 +1,3 @@ .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation -echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. 
If you get a c10.dll error you need to install vc redist that you can find: https://aka.ms/vc14/vc_redist.x64.exe pause diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index d72ece2ce..8f07a7b1c 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -20,7 +20,7 @@ jobs: git_tag: ${{ inputs.git_tag }} cache_tag: "cu130" python_minor: "13" - python_patch: "9" + python_patch: "11" rel_name: "nvidia" rel_extra_name: "" test_release: true @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 7.1.1" + name: "Release AMD ROCm 7.2" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm711" + cache_tag: "rocm72" python_minor: "12" python_patch: "10" rel_name: "amd" diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 28484a9d1..f501b7b31 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -117,7 +117,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..9160242e9 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index fd70aff23..581c0474b 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -13,7 +13,7 @@ jobs: - name: Checkout ComfyUI uses: actions/checkout@v4 with: - repository: "comfyanonymous/ComfyUI" + repository: "Comfy-Org/ComfyUI" path: "ComfyUI" - uses: actions/setup-python@v4 with: @@ -32,7 +32,9 @@ jobs: working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | - if grep -qE "Exception|Error" console_output.log; then + grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log + cat console_output_filtered.log + if grep -qE "Exception|Error" console_output_filtered.log; then echo "Unhandled exception/error found in server log." 
exit 1 fi diff --git a/.github/workflows/update-ci-container.yml b/.github/workflows/update-ci-container.yml new file mode 100644 index 000000000..f7972e056 --- /dev/null +++ b/.github/workflows/update-ci-container.yml @@ -0,0 +1,59 @@ +name: "CI: Update CI Container" + +on: + release: + types: [published] + workflow_dispatch: + inputs: + version: + description: 'ComfyUI version (e.g., v0.7.0)' + required: true + type: string + +jobs: + update-ci-container: + runs-on: ubuntu-latest + # Skip pre-releases unless manually triggered + if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease + steps: + - name: Get version + id: version + run: | + if [ "${{ github.event_name }}" = "release" ]; then + VERSION="${{ github.event.release.tag_name }}" + else + VERSION="${{ inputs.version }}" + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Checkout comfyui-ci-container + uses: actions/checkout@v4 + with: + repository: comfy-org/comfyui-ci-container + token: ${{ secrets.CI_CONTAINER_PAT }} + + - name: Check current version + id: current + run: | + CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown") + echo "current_version=$CURRENT" >> $GITHUB_OUTPUT + + - name: Update Dockerfile + run: | + VERSION="${{ steps.version.outputs.version }}" + sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile + + - name: Create Pull Request + id: create-pr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.CI_CONTAINER_PAT }} + branch: automation/comfyui-${{ steps.version.outputs.version }} + title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" + body: | + Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}` + + **Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }} + + labels: automation + commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index f61ee21a2..93e01ac93 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "11" # push: # branches: # - master diff --git a/README.md b/README.md index 45d721681..157fa9f70 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/) - Latent previews with [TAESD](#how-to-show-high-quality-previews) - Works fully offline: core will never download anything unless you want to. -- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). +- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview) disable with: `--disable-api-nodes` - [Config file](extra_model_paths.yaml.example) to set the search paths for models. 
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/) @@ -188,7 +188,7 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp If you have trouble extracting it, right click the file -> properties -> unblock -Update your Nvidia drivers if it doesn't start. +The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. #### Alternative Downloads: @@ -213,11 +213,11 @@ comfy install ## Manual Install (Windows, Linux) -Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies. +Python 3.14 works but some custom nodes may have issues. The free threaded variant works but some dependencies will enable the GIL so it's not fully supported. Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12 -torch 2.4 and above is supported but some features might only work on newer versions. We generally recommend using the latest major version of pytorch unless it is less than 2 weeks old. +torch 2.4 and above is supported but some features and optimizations might only work on newer versions. We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. ### Instructions: @@ -234,7 +234,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` -This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` @@ -245,7 +245,7 @@ These have less hardware support than the builds above but they work on windows. 
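Whichever of the RDNA-specific wheels below ends up installed, a quick sanity check can confirm the ROCm build is actually active. This is a sketch, not part of the README; it relies on ROCm builds exposing GPUs through the `torch.cuda` API and reporting a HIP version:

```python
# Not from the README: confirm a ROCm PyTorch wheel is installed and can see the GPU.
import torch

print(torch.__version__)          # e.g. "2.x.x+rocmX.Y" for a ROCm wheel
print(torch.version.hip)          # None on CUDA/CPU builds, a version string on ROCm
print(torch.cuda.is_available())  # ROCm devices are surfaced through the CUDA API
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```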
RDNA 3 (RX 7000 series): -```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/``` RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000..1e10b94dc --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,174 @@ +""" +Initial assets schema +Revision ID: 0001_assets +Revises: None +Create Date: 2025-12-10 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity + op.create_table( + "assets", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + + # ASSETS_INFO: user-visible references + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=512), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), 
nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_CACHE_STATE: N:1 local cache rows per Asset + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary + tags_table = sa.table( + "tags", + sa.column("name", sa.String(length=512)), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, + + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + + {"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, + ], + ) + + +def downgrade() -> None: + 
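# Reverse of upgrade(): drop the dependent tables first (asset_info_meta,
# asset_cache_state, asset_info_tags) so their foreign-key targets
# (tags, assets_info, assets) can be dropped last.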
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py new file mode 100644 index 000000000..7676e50b4 --- /dev/null +++ b/app/assets/api/routes.py @@ -0,0 +1,514 @@ +import logging +import uuid +import urllib.parse +import os +import contextlib +from aiohttp import web + +from pydantic import ValidationError + +import app.assets.manager as manager +from app import user_manager +from app.assets.api import schemas_in +from app.assets.helpers import get_query_dict +from app.assets.scanner import seed_assets + +import folder_paths + +ROUTES = web.RouteTableDef() +USER_MANAGER: user_manager.UserManager | None = None + +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. 
+ +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global USER_MANAGER + USER_MANAGER = user_manager_instance + app.add_routes(ROUTES) + +def _error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: + return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + + +def _validation_error_response(code: str, ve: ValidationError) -> web.Response: + return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) + + +@ROUTES.head("/api/assets/hash/{hash}") +async def head_asset_by_hash(request: web.Request) -> web.Response: + hash_str = request.match_info.get("hash", "").strip().lower() + if not hash_str or ":" not in hash_str: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + exists = manager.asset_exists(asset_hash=hash_str) + return web.Response(status=200 if exists else 404) + + +@ROUTES.get("/api/assets") +async def list_assets(request: web.Request) -> web.Response: + """ + GET request to list assets. + """ + query_dict = get_query_dict(request) + try: + q = schemas_in.ListAssetsQuery.model_validate(query_dict) + except ValidationError as ve: + return _validation_error_response("INVALID_QUERY", ve) + + payload = manager.list_assets( + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + offset=q.offset, + sort=q.sort, + order=q.order, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(payload.model_dump(mode="json", exclude_none=True)) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") +async def get_asset(request: web.Request) -> web.Response: + """ + GET request to get an asset's info as JSON. + """ + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + result = manager.get_asset( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as e: + return _error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + except Exception: + logging.exception( + "get_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +async def download_asset_content(request: web.Request) -> web.Response: + # question: do we need disposition? could we just stick with one of these? 
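    # Both values have a use: "inline" lets a browser render the content in place
    # (e.g. image previews), while "attachment" forces a download prompt; anything
    # else falls back to "attachment" below.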
+ disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + abs_path, content_type, filename = manager.resolve_asset_content_for_download( + asset_info_id=str(uuid.UUID(request.match_info["id"])), + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") + cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + + file_size = os.path.getsize(abs_path) + logging.info( + "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + abs_path, + file_size, + file_size / (1024 * 1024), + content_type, + filename, + ) + + async def file_sender(): + chunk_size = 64 * 1024 + with open(abs_path, "rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + yield chunk + + return web.Response( + body=file_sender(), + content_type=content_type, + headers={ + "Content-Disposition": cd, + "Content-Length": str(file_size), + }, + ) + + +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + +@ROUTES.post("/api/assets") +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + if not (request.content_type or "").lower().startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + + if s: + if ":" not in s: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + provided_hash = f"{algo}:{digest}" + try: + provided_hash_exists = 
manager.asset_exists(asset_hash=provided_hash) + except Exception: + provided_hash_exists = None # do not fail the whole request here + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # If client supplied a hash that we know exists, drain but do not write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + continue # Do not create temp file; we will create AssetInfo from the existing content + + # Otherwise, store to temp for hashing/ingest + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + try: + if os.path.exists(tmp_path or ""): + os.remove(tmp_path) + finally: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + # If client did not send file, and we are not doing a from-hash fast path -> error + if not file_present and not (provided_hash and provided_hash_exists): + return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") + + if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + # Empty upload is only acceptable if we are fast-pathing from existing hash + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + try: + spec = schemas_in.UploadAssetSpec.model_validate({ + "tags": tags_raw, + "name": provided_name, + "user_metadata": user_metadata_raw, + "hash": provided_hash, + }) + except ValidationError as ve: + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _validation_error_response("INVALID_BODY", ve) + + # Validate models category against configured folders (consistent with previous behavior) + if spec.tags and spec.tags[0] == "models": + if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response( + 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" + ) + + owner_id = USER_MANAGER.get_request_user_id(request) + + # Fast path: if a valid provided hash exists, create AssetInfo without writing anything + if spec.hash and provided_hash_exists is True: + try: + result = manager.create_asset_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + ) + except Exception: + logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if result is None: + return 
_error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + + # Drain temp if we accidentally saved (e.g., hash field came after file) + if tmp_path and os.path.exists(tmp_path): + with contextlib.suppress(Exception): + os.remove(tmp_path) + + status = 200 if (not result.created_new) else 201 + return web.json_response(result.model_dump(mode="json"), status=status) + + # Otherwise, we must have a temp file path to ingest + if not tmp_path or not os.path.exists(tmp_path): + # The only case we reach here without a temp file is: client sent a hash that does not exist and no file + return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") + + try: + created = manager.upload_asset_from_temp_path( + spec, + temp_path=tmp_path, + client_filename=file_client_name, + owner_id=owner_id, + expected_asset_hash=spec.hash, + ) + status = 201 if created.created_new else 200 + return web.json_response(created.model_dump(mode="json"), status=status) + except ValueError as e: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + msg = str(e) + if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": + return _error_response( + 400, + "HASH_MISMATCH", + "Uploaded file hash does not match provided hash.", + ) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except Exception: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.update_asset( + asset_info_id=asset_info_id, + name=body.name, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "update_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +async def delete_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + delete_content = request.query.get("delete_content") + delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} + + try: + deleted = manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + delete_content_if_orphan=delete_content, + ) + except Exception: + logging.exception( + "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} 
not found.") + return web.Response(status=204) + + +@ROUTES.get("/api/tags") +async def get_tags(request: web.Request) -> web.Response: + """ + GET request to list all tags based on query parameters. + """ + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as e: + return web.json_response( + {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, + status=400, + ) + + result = manager.list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def add_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.post("/api/assets/seed") +async def seed_assets_endpoint(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output).""" + try: + payload = await request.json() + roots = payload.get("roots", ["models", "input", "output"]) + except Exception: + roots = ["models", "input", "output"] + + valid_roots = [r for r in roots if r in ("models", "input", "output")] + if not valid_roots: + return _error_response(400, "INVALID_BODY", "No valid roots specified") + + try: + 
seed_assets(tuple(valid_roots)) + except Exception: + logging.exception("seed_assets failed for roots=%s", valid_roots) + return _error_response(500, "INTERNAL", "Seed operation failed") + + return web.json_response({"seeded": valid_roots}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py new file mode 100644 index 000000000..6707ffb0c --- /dev/null +++ b/app/assets/api/schemas_in.py @@ -0,0 +1,264 @@ +import json +from typing import Any, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, + model_validator, +) + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: dict[str, Any] | None = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class UpdateAssetBody(BaseModel): + name: str | None = None + user_metadata: dict[str, Any] | None = None + + @model_validator(mode="after") + def _at_least_one(self): + if self.name is None and self.user_metadata is None: + raise ValueError("Provide at least one of: name, user_metadata.") + return self + + +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + s = (v or "").strip().lower() + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return s + + @field_validator("tags", mode="before") + @classmethod + def _tags_norm(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: str | None = Field(None, 
min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: str | None) -> str | None: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) + + @field_validator("tags") + @classmethod + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); + if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + - name: display name + - user_metadata: arbitrary JSON object (optional) + - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + + Files created via this endpoint are stored on disk using the **content hash** as the filename stem + and the original extension is preserved when available. + """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + tags: list[str] = Field(..., min_length=1) + name: str | None = Field(default=None, max_length=512, description="Display Name") + user_metadata: dict[str, Any] = Field(default_factory=dict) + hash: str | None = Field(default=None) + + @field_validator("hash", mode="before") + @classmethod + def _parse_hash(cls, v): + if v is None: + return None + s = str(v).strip().lower() + if not s: + return None + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return f"{algo}:{digest}" + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + """ + Accepts a list of strings (possibly multiple form fields), + where each string can be: + - JSON array (e.g., '["models","loras","foo"]') + - comma-separated ('models, loras, foo') + - single token ('models') + Returns a normalized, deduplicated, ordered list. 
+ """ + items: list[str] = [] + if v is None: + return [] + if isinstance(v, str): + v = [v] + + if isinstance(v, list): + for item in v: + if item is None: + continue + s = str(item).strip() + if not s: + continue + if s.startswith("["): + try: + arr = json.loads(s) + if isinstance(arr, list): + items.extend(str(x) for x in arr) + continue + except Exception: + pass # fallback to CSV parse below + items.extend([p for p in s.split(",") if p.strip()]) + else: + return [] + + # normalize + dedupe + norm = [] + seen = set() + for t in items: + tnorm = str(t).strip().lower() + if tnorm and tnorm not in seen: + seen.add(tnorm) + norm.append(tnorm) + return norm + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v or {} + if isinstance(v, str): + s = v.strip() + if not s: + return {} + try: + parsed = json.loads(s) + except Exception as e: + raise ValueError(f"user_metadata must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("user_metadata must be a JSON object") + return parsed + return {} + + @model_validator(mode="after") + def _validate_order(self): + if not self.tags: + raise ValueError("tags must be provided and non-empty") + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError("models uploads require a category tag as the second tag") + return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py new file mode 100644 index 000000000..b6fb3da0c --- /dev/null +++ b/app/assets/api/schemas_out.py @@ -0,0 +1,93 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, field_serializer + + +class AssetSummary(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + preview_url: str | None = None + created_at: datetime | None = None + updated_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "updated_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class AssetsList(BaseModel): + assets: list[AssetSummary] + total: int + has_more: bool + + +class AssetUpdated(BaseModel): + id: str + name: str + asset_hash: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + updated_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class AssetDetail(BaseModel): + id: str + name: str + asset_hash: str | None = None + size: int | None = None + mime_type: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + preview_id: str | None = None + created_at: datetime | None = None + last_access_time: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("created_at", "last_access_time") + def _ser_dt(self, v: datetime | None, _info): + return v.isoformat() if v else None + + +class AssetCreated(AssetDetail): + created_new: bool + + 
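The list endpoint returns these summaries via `model_dump(mode="json", exclude_none=True)`, so the `field_serializer` hooks above are what turn datetimes into ISO-8601 strings. A small, self-contained sketch (sample values invented; the import path simply mirrors where this module lives in the diff):

```python
# Sketch of how the routes serialize these response models.
from datetime import datetime

from app.assets.api.schemas_out import AssetSummary

summary = AssetSummary(
    id="3f2c9d1e-8a4b-4c6d-9e0f-1a2b3c4d5e6f",
    name="example.safetensors",
    asset_hash="blake3:0a1b2c",
    size=1234,
    mime_type="application/octet-stream",
    tags=["models", "checkpoints"],
    created_at=datetime(2025, 1, 1, 12, 0, 0),
)

# mode="json" triggers the field_serializer above, so datetimes become ISO-8601
# strings; exclude_none drops updated_at / last_access_time, which were left unset.
payload = summary.model_dump(mode="json", exclude_none=True)
assert payload["created_at"] == "2025-01-01T12:00:00"
```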
+class TagUsage(BaseModel): + name: str + count: int + type: str + + +class TagsList(BaseModel): + tags: list[TagUsage] = Field(default_factory=list) + total: int + has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) diff --git a/app/assets/database/bulk_ops.py b/app/assets/database/bulk_ops.py new file mode 100644 index 000000000..c7b75290a --- /dev/null +++ b/app/assets/database/bulk_ops.py @@ -0,0 +1,204 @@ +import os +import uuid +import sqlalchemy +from typing import Iterable +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import utcnow +from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta + +MAX_BIND_PARAMS = 800 + +def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]: + if not rows: + return [] + rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i:i + rows_per_stmt] + +def _iter_chunks(seq, n: int): + for i in range(0, len(seq), n): + yield seq[i:i + n] + +def _rows_per_stmt(cols: int) -> int: + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def seed_from_paths_batch( + session: Session, + *, + specs: list[dict], + owner_id: str = "", +) -> dict: + """Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + """ + if not specs: + return {"inserted_infos": 0, "won_states": 0, "lost_states": 0} + + now = utcnow() + asset_rows: list[dict] = [] + state_rows: list[dict] = [] + path_to_asset: dict[str, str] = {} + asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row + path_list: list[str] = [] + + for sp in specs: + ap = os.path.abspath(sp["abs_path"]) + aid = str(uuid.uuid4()) + iid = str(uuid.uuid4()) + path_list.append(ap) + path_to_asset[ap] = aid + + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) + asset_to_info[aid] = { + "id": iid, + "owner_id": owner_id, + "name": sp["info_name"], + "asset_id": aid, + "preview_id": None, + "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None, + "created_at": now, + "updated_at": now, + "last_access_time": now, + "_tags": sp["tags"], + "_filename": sp["fname"], + } + + # insert all seed Assets (hash=NULL) + ins_asset = sqlite.insert(Asset) + for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)): + session.execute(ins_asset, chunk) + + # try to claim AssetCacheState (file_path) + # Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted + ins_state = ( + sqlite.insert(AssetCacheState) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)): + session.execute(ins_state, chunk) + + # Query to find which of our paths won (were actually inserted) + 
winners_by_path: set[str] = set() + for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetCacheState.file_path) + .where(AssetCacheState.file_path.in_(chunk)) + .where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk])) + ) + winners_by_path.update(result.scalars().all()) + + all_paths_set = set(path_list) + losers_by_path = all_paths_set - winners_by_path + lost_assets = [path_to_asset[p] for p in losers_by_path] + if lost_assets: # losers get their Asset removed + for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS): + session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk))) + + if not winners_by_path: + return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + + # insert AssetInfo only for winners + # Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted + winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] + ins_info = ( + sqlite.insert(AssetInfo) + .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name]) + ) + for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)): + session.execute(ins_info, chunk) + + # Query to find which info rows were actually inserted (by matching our generated IDs) + all_info_ids = [row["id"] for row in winner_info_rows] + inserted_info_ids: set[str] = set() + for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS): + result = session.execute( + sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) + ) + inserted_info_ids.update(result.scalars().all()) + + # build and insert tag + meta rows for the AssetInfo + tag_rows: list[dict] = [] + meta_rows: list[dict] = [] + if inserted_info_ids: + for row in winner_info_rows: + iid = row["id"] + if iid not in inserted_info_ids: + continue + for t in row["_tags"]: + tag_rows.append({ + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + }) + if row["_filename"]: + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS) + return { + "inserted_infos": len(inserted_info_ids), + "won_states": len(winners_by_path), + "lost_states": len(losers_by_path), + } + + +def bulk_insert_tags_and_meta( + session: Session, + *, + tag_rows: list[dict], + meta_rows: list[dict], + max_bind_params: int, +) -> None: + """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING. 
+ - tag_rows keys: asset_info_id, tag_name, origin, added_at + - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json + """ + if tag_rows: + ins_links = ( + sqlite.insert(AssetInfoTag) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params): + session.execute(ins_links, chunk) + if meta_rows: + ins_meta = ( + sqlite.insert(AssetInfoMeta) + .on_conflict_do_nothing( + index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + ) + ) + for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params): + session.execute(ins_meta, chunk) diff --git a/app/assets/database/models.py b/app/assets/database/models.py new file mode 100644 index 000000000..3cd28f68b --- /dev/null +++ b/app/assets/database/models.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import uuid +from datetime import datetime + +from typing import Any +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + CheckConstraint, + DateTime, + ForeignKey, + Index, + Integer, + Numeric, + String, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship + +from app.assets.helpers import utcnow +from app.database.models import to_dict, Base + + +class Asset(Base): + __tablename__ = "assets" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + hash: Mapped[str | None] = mapped_column(String(256), nullable=True) + size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) + mime_type: Mapped[str | None] = mapped_column(String(255)) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + infos: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), + foreign_keys=lambda: [AssetInfo.asset_id], + cascade="all,delete-orphan", + passive_deletes=True, + ) + + preview_of: Mapped[list[AssetInfo]] = relationship( + "AssetInfo", + back_populates="preview_asset", + primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id), + foreign_keys=lambda: [AssetInfo.preview_id], + viewonly=True, + ) + + cache_states: Mapped[list[AssetCacheState]] = relationship( + back_populates="asset", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + __table_args__ = ( + Index("uq_assets_hash", "hash", unique=True), + Index("ix_assets_mime_type", "mime_type"), + CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetCacheState(Base): + __tablename__ = "asset_cache_state" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) + file_path: Mapped[str] = mapped_column(Text, nullable=False) + mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) + needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + + asset: Mapped[Asset] = relationship(back_populates="cache_states") + + __table_args__ = ( + Index("ix_asset_cache_state_file_path", "file_path"), + Index("ix_asset_cache_state_asset_id", "asset_id"), + 
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + return to_dict(self, include_none=include_none) + + def __repr__(self) -> str: + return f"" + + +class AssetInfo(Base): + __tablename__ = "assets_info" + + id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") + name: Mapped[str] = mapped_column(String(512), nullable=False) + asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) + preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) + + asset: Mapped[Asset] = relationship( + "Asset", + back_populates="infos", + foreign_keys=[asset_id], + lazy="selectin", + ) + preview_asset: Mapped[Asset | None] = relationship( + "Asset", + back_populates="preview_of", + foreign_keys=[preview_id], + ) + + metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + ) + + tag_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="asset_info", + cascade="all,delete-orphan", + passive_deletes=True, + overlaps="tags,asset_infos", + ) + + tags: Mapped[list[Tag]] = relationship( + secondary="asset_info_tags", + back_populates="asset_infos", + lazy="selectin", + viewonly=True, + overlaps="tag_links,asset_info_links,asset_infos,tag", + ) + + __table_args__ = ( + UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), + Index("ix_assets_info_owner_id", "owner_id"), + Index("ix_assets_info_asset_id", "asset_id"), + Index("ix_assets_info_name", "name"), + Index("ix_assets_info_created_at", "created_at"), + Index("ix_assets_info_last_access_time", "last_access_time"), + ) + + def to_dict(self, include_none: bool = False) -> dict[str, Any]: + data = to_dict(self, include_none=include_none) + data["tags"] = [t.name for t in self.tags] + return data + + def __repr__(self) -> str: + return f"" + + +class AssetInfoMeta(Base): + __tablename__ = "asset_info_meta" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + key: Mapped[str] = mapped_column(String(256), primary_key=True) + ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) + + val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True) + val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True) + val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") + + __table_args__ = ( + Index("ix_asset_info_meta_key", "key"), + 
Index("ix_asset_info_meta_key_val_str", "key", "val_str"), + Index("ix_asset_info_meta_key_val_num", "key", "val_num"), + Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), + ) + + +class AssetInfoTag(Base): + __tablename__ = "asset_info_tags" + + asset_info_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True + ) + tag_name: Mapped[str] = mapped_column( + String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True + ) + origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") + added_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=utcnow + ) + + asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") + tag: Mapped[Tag] = relationship(back_populates="asset_info_links") + + __table_args__ = ( + Index("ix_asset_info_tags_tag_name", "tag_name"), + Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), + ) + + +class Tag(Base): + __tablename__ = "tags" + + name: Mapped[str] = mapped_column(String(512), primary_key=True) + tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") + + asset_info_links: Mapped[list[AssetInfoTag]] = relationship( + back_populates="tag", + overlaps="asset_infos,tags", + ) + asset_infos: Mapped[list[AssetInfo]] = relationship( + secondary="asset_info_tags", + back_populates="tags", + viewonly=True, + overlaps="asset_info_links,tag_links,tags,asset_info", + ) + + __table_args__ = ( + Index("ix_tags_tag_type", "tag_type"), + ) + + def __repr__(self) -> str: + return f"" diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py new file mode 100644 index 000000000..d6b33ec7b --- /dev/null +++ b/app/assets/database/queries.py @@ -0,0 +1,976 @@ +import os +import logging +import sqlalchemy as sa +from collections import defaultdict +from datetime import datetime +from typing import Iterable, Any +from sqlalchemy import select, delete, exists, func +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, contains_eager, noload +from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.helpers import ( + compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow +) +from typing import Sequence + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) + + +def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. 
+ """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_info_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + no_row_for_key = sa.not_( + sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + ) + ) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) + return sa.or_(no_row_for_key, null_row) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value)) + if isinstance(value, (int, float)): + from decimal import Decimal + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetInfoMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetInfoMeta.val_str == value) + return _exists_for_pred(key, AssetInfoMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt + + +def asset_exists_by_hash( + session: Session, + *, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. 
+ """ + row = ( + session.execute( + select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).first() + return row is not None + + +def asset_info_exists_for_asset_id( + session: Session, + *, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (session.execute(q)).first() is not None + + +def get_asset_by_hash( + session: Session, + *, + asset_hash: str, +) -> Asset | None: + return ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +def get_asset_info_by_id( + session: Session, + *, + asset_info_id: str, +) -> AssetInfo | None: + return session.get(AssetInfo, asset_info_id) + + +def list_asset_infos_page( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> tuple[list[AssetInfo], dict[str, list[str]], int]: + base = ( + select(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) + .where(visible_owner_clause(owner_id)) + ) + + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetInfo.name, + "created_at": AssetInfo.created_at, + "updated_at": AssetInfo.updated_at, + "last_access_time": AssetInfo.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetInfo.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetInfo) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where(visible_owner_clause(owner_id)) + ) + if name_contains: + escaped, esc = escape_like_prefix(name_contains) + count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc)) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int((session.execute(count_stmt)).scalar_one() or 0) + + infos = (session.execute(base)).unique().scalars().all() + + id_list: list[str] = [i.id for i in infos] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetInfoTag.asset_info_id, Tag.name) + .join(Tag, Tag.name == AssetInfoTag.tag_name) + .where(AssetInfoTag.asset_info_id.in_(id_list)) + .order_by(AssetInfoTag.added_at) + ) + for aid, tag_name in rows.all(): + tag_map[aid].append(tag_name) + + return infos, tag_map, total + + +def fetch_asset_info_asset_and_tags( + session: Session, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset, list[str]] | None: + stmt = ( + select(AssetInfo, Asset, Tag.name) + .join(Asset, Asset.id == AssetInfo.asset_id) + .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) + .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + 
.options(noload(AssetInfo.tags)) + .order_by(Tag.name.asc()) + ) + + rows = (session.execute(stmt)).all() + if not rows: + return None + + first_info, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _info, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_info, first_asset, tags + + +def fetch_asset_info_and_asset( + session: Session, + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset] | None: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + +def list_cache_states_by_asset_id( + session: Session, *, asset_id: str +) -> Sequence[AssetCacheState]: + return ( + session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +def touch_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def create_asset_info_for_existing_asset( + session: Session, + *, + asset_hash: str, + name: str, + user_metadata: dict | None = None, + tags: Sequence[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + with session.begin_nested(): + session.add(info) + session.flush() + except IntegrityError: + existing = ( + session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +def set_asset_info_tags( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + 
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + session.flush() + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +def replace_asset_info_metadata_projection( + session: Session, + *, + asset_info_id: str, + user_metadata: dict | None = None, +) -> None: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + session.flush() + + session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def ingest_fs_asset( + session: Session, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: dict | None = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. 
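+
+    Illustrative call inside an open Session (path, size, mtime and hash are
+    hypothetical values):
+
+        out = ingest_fs_asset(
+            session,
+            asset_hash="blake3:" + digest,
+            abs_path="/comfy/models/vae/ae.safetensors",
+            size_bytes=334641190,
+            mtime_ns=1700000000000000000,
+            mime_type="application/octet-stream",
+            info_name="ae.safetensors",
+            tags=["models", "vae"],
+        )
+        # out is a dict of flags/ids, e.g.:
+        # {"asset_created": True, "asset_updated": False,
+        #  "state_created": True, "state_updated": False,
+        #  "asset_info_id": "<uuid of the created/updated AssetInfo>"}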
+ """ + locator = os.path.abspath(abs_path) + now = utcnow() + + if preview_id: + if not session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + vals = { + "hash": asset_hash, + "size_bytes": int(size_bytes), + "mime_type": mime_type, + "created_at": now, + } + res = session.execute( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() + if not asset: + raise RuntimeError("Asset row not found after upsert.") + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + ins = ( + sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + + res = session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), + ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + try: + with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + session.execute( + 
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + computed_filename = compute_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +def update_asset_info_full( + session: Session, + *, + asset_info_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: dict | None = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + session.flush() + + return info + + +def delete_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + return int((session.execute(stmt)).rowcount or 0) > 0 + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetInfoTag.tag_name.label("tag_name"), + func.count(AssetInfoTag.asset_info_id).label("cnt"), + ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == 
AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) + .group_by(AssetInfoTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = escape_like_prefix(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + total_q = total_q.where( + Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) + ) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: + return [ + tag_name for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def add_tags_to_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +def remove_tags_from_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], 
"not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) + + +def set_asset_info_preview( + session: Session, + *, + asset_info_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + session.flush() diff --git a/app/assets/database/tags.py b/app/assets/database/tags.py new file mode 100644 index 000000000..3ab6497c2 --- /dev/null +++ b/app/assets/database/tags.py @@ -0,0 +1,62 @@ +from typing import Iterable + +import sqlalchemy +from sqlalchemy.orm import Session +from sqlalchemy.dialects import sqlite + +from app.assets.helpers import normalize_tags, utcnow +from app.assets.database.models import Tag, AssetInfoTag, AssetInfo + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + return session.execute(ins) + +def add_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sqlalchemy.select( + AssetInfo.id.label("asset_info_id"), + sqlalchemy.literal("missing").label("tag_name"), + sqlalchemy.literal(origin).label("origin"), + sqlalchemy.literal(utcnow()).label("added_at"), + ) + .where(AssetInfo.asset_id == asset_id) + .where( + sqlalchemy.not_( + sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + ) + ) + ) + session.execute( + sqlite.insert(AssetInfoTag) + .from_select( + ["asset_info_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + ) + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sqlalchemy.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) diff --git a/app/assets/hashing.py b/app/assets/hashing.py new 
file mode 100644 index 000000000..4b72084b9 --- /dev/null +++ b/app/assets/hashing.py @@ -0,0 +1,75 @@ +from blake3 import blake3 +from typing import IO +import os +import asyncio + + +DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB + +# NOTE: this allows hashing different representations of a file-like object +def blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """ + Returns a BLAKE3 hex digest for ``fp``, which may be: + - a filename (str/bytes) or PathLike + - an open binary file object + If ``fp`` is a file object, it must be opened in **binary** mode and support + ``read``, ``seek``, and ``tell``. The function will seek to the start before + reading and will attempt to restore the original position afterward. + """ + # duck typing to check if input is a file-like object + if hasattr(fp, "read"): + return _hash_file_obj(fp, chunk_size) + + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + +async def blake3_hash_async( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, +) -> str: + """Async wrapper for ``blake3_hash_sync``. + Uses a worker thread so the event loop remains responsive. + """ + # If it is a path, open inside the worker thread to keep I/O off the loop. + if hasattr(fp, "read"): + return await asyncio.to_thread(blake3_hash, fp, chunk_size) + + def _worker() -> str: + with open(os.fspath(fp), "rb") as f: + return _hash_file_obj(f, chunk_size) + + return await asyncio.to_thread(_worker) + + +def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str: + """ + Hash an already-open binary file object by streaming in chunks. + - Seeks to the beginning before reading (if supported). + - Restores the original position afterward (if tell/seek are supported). + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + # in case file object is already open and not at the beginning, track so can be restored after hashing + orig_pos = file_obj.tell() + + try: + # seek to the beginning before reading + if orig_pos != 0: + file_obj.seek(0) + + h = blake3() + while True: + chunk = file_obj.read(chunk_size) + if not chunk: + break + h.update(chunk) + return h.hexdigest() + finally: + # restore original position in file object, if needed + if orig_pos != 0: + file_obj.seek(orig_pos) diff --git a/app/assets/helpers.py b/app/assets/helpers.py new file mode 100644 index 000000000..5030b123a --- /dev/null +++ b/app/assets/helpers.py @@ -0,0 +1,312 @@ +import contextlib +import os +from decimal import Decimal +from aiohttp import web +from datetime import datetime, timezone +from pathlib import Path +from typing import Literal, Any + +import folder_paths + + +RootType = Literal["models", "input", "output"] +ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output") + +def get_query_dict(request: web.Request) -> dict[str, Any]: + """ + Gets a dictionary of query parameters from the request. + + 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. 
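+
+    Illustrative example (hypothetical query string):
+
+        # GET /assets?include_tags=models&include_tags=vae&limit=20
+        # yields:
+        #   {"include_tags": ["models", "vae"], "limit": "20"}
+        # i.e. repeated keys become lists, single keys stay plain strings
+        # (values are still strings here; type coercion is left to the Pydantic schema).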
+ """ + query_dict = { + key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) + for key in request.query.keys() + } + return query_dict + +def list_tree(base_dir: str) -> list[str]: + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out + +def prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + +def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char itself in a LIKE prefix. + Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like(). + """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape + +def fast_asset_file_check( + *, + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + if mtime_db is None: + return False + actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + if int(mtime_db) != int(actual_mtime_ns): + return False + sz = int(size_db or 0) + if sz > 0: + return int(stat_result.st_size) == sz + return True + +def utcnow() -> datetime: + """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. + + We trust `folder_paths.folder_names_and_paths` and include a category if + *any* of its base paths lies under the Comfy `models_dir`. + """ + targets: list[tuple[str, list[str]]] = [] + models_root = os.path.abspath(folder_paths.models_dir) + for name, values in folder_paths.folder_names_and_paths.items(): + paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... 
from breaking ComfyUI + if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): + targets.append((name, paths)) + return targets + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + root = tags[0] + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + else: + base_dir = os.path.abspath( + folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + ) + raw_subdirs = tags[1:] + for i in raw_subdirs: + if i in (".", ".."): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + +def ensure_within_base(candidate: str, base: str) -> None: + cand_abs = os.path.abspath(candidate) + base_abs = os.path.abspath(base) + try: + if os.path.commonpath([cand_abs, base_abs]) != base_abs: + raise ValueError("destination escapes base directory") + except Exception: + raise ValueError("invalid destination path") + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + NOTE: this is a temporary helper, used only for initializing metadata["filename"] field. + """ + try: + root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + +def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: + """Given an absolute or relative file path, determine which root category the path belongs to: + - 'input' if the file resides under `folder_paths.get_input_directory()` + - 'output' if the file resides under `folder_paths.get_output_directory()` + - 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()` + + Returns: + (root_category, relative_path_inside_that_root) + For 'models', the relative path is prefixed with the category name: + e.g. ('models', 'vae/test/sub/ae.safetensors') + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. 
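+
+    Illustrative results (hypothetical install layout with models under
+    /comfy/models and inputs under /comfy/input):
+
+        get_relative_to_root_category_path_of_asset("/comfy/models/vae/test/ae.safetensors")
+        # -> ("models", "vae/test/ae.safetensors")
+
+        get_relative_to_root_category_path_of_asset("/comfy/input/img.png")
+        # -> ("input", "img.png")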
+ """ + fp_abs = os.path.abspath(file_path) + + def _is_within(child: str, parent: str) -> bool: + try: + return os.path.commonpath([child, parent]) == parent + except Exception: + return False + + def _rel(child: str, parent: str) -> str: + return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _is_within(fp_abs, input_base): + return "input", _rel(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _is_within(fp_abs, output_base): + return "output", _rel(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _rel(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return a tuple (name, tags) derived from a filesystem path. + + Semantics: + - Root category is determined by `get_relative_to_root_category_path_of_asset`. + - The returned `name` is the base filename with extension from the relative path. + - The returned `tags` are: + [root_category] + parent folders of the relative path (in order) + For 'models', this means: + file '/.../ModelsDir/vae/test_tag/ae.safetensors' + -> root_category='models', some_path='vae/test_tag/ae.safetensors' + -> name='ae.safetensors', tags=['models', 'vae', 'test_tag'] + + Raises: + ValueError: if the path does not belong to input, output, or configured model bases. + """ + root_category, some_path = get_relative_to_root_category_path_of_asset(file_path) + p = Path(some_path) + parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) + +def normalize_tags(tags: list[str] | None) -> list[str]: + """ + Normalize a list of tags by: + - Stripping whitespace and converting to lowercase. + - Removing duplicates. + """ + return [t.strip().lower() for t in (tags or []) if (t or "").strip()] + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + for b in bases: + base_abs = os.path.abspath(b) + with contextlib.suppress(Exception): + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + if allowed: + out.append(abs_path) + return out + +def is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + +def project_kv(key: str, value): + """ + Turn a metadata key/value into typed projection rows. 
+ Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/assets/manager.py b/app/assets/manager.py new file mode 100644 index 000000000..a68c8c8ae --- /dev/null +++ b/app/assets/manager.py @@ -0,0 +1,516 @@ +import os +import mimetypes +import contextlib +from typing import Sequence + +from app.database.db import create_session +from app.assets.api import schemas_out, schemas_in +from app.assets.database.queries import ( + asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, + fetch_asset_info_asset_and_tags, + fetch_asset_info_and_asset, + create_asset_info_for_existing_asset, + touch_asset_info_by_id, + update_asset_info_full, + delete_asset_info_by_id, + list_cache_states_by_asset_id, + list_asset_infos_page, + list_tags_with_usage, + get_asset_tags, + add_tags_to_asset_info, + remove_tags_from_asset_info, + pick_best_live_path, + ingest_fs_asset, + set_asset_info_preview, +) +from app.assets.helpers import resolve_destination_from_tags, ensure_within_base +from app.assets.database.models import Asset + + +def _safe_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +def _safe_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + if n: + return n + return fallback + + +def asset_exists(*, asset_hash: str) -> bool: + """ + Check if an asset with a given hash exists in database. 
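+
+    Unlike the query-layer helper, this opens its own session. Illustrative usage
+    (the digest value is hypothetical):
+
+        if not asset_exists(asset_hash="blake3:" + digest):
+            ...  # content is new; hash and ingest the file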
+ """ + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def list_assets( + *, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", + owner_id: str = "", +) -> schemas_out.AssetsList: + sort = _safe_sort_field(sort) + order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() + + with create_session() as session: + infos, tag_map, total = list_asset_infos_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + summaries: list[schemas_out.AssetSummary] = [] + for info in infos: + asset = info.asset + tags = tag_map.get(info.id, []) + summaries.append( + schemas_out.AssetSummary( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + created_at=info.created_at, + updated_at=info.updated_at, + last_access_time=info.last_access_time, + ) + ) + + return schemas_out.AssetsList( + assets=summaries, + total=total, + has_more=(offset + len(summaries)) < total, + ) + + +def get_asset( + *, + asset_info_id: str, + owner_id: str = "", +) -> schemas_out.AssetDetail: + with create_session() as session: + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise ValueError(f"AssetInfo {asset_info_id} not found") + info, asset, tag_names = res + preview_id = info.preview_id + + return schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + + +def resolve_asset_content_for_download( + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[str, str, str]: + with create_session() as session: + pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not pair: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info, asset = pair + states = list_cache_states_by_asset_id(session, asset_id=asset.id) + abs_path = pick_best_live_path(states) + if not abs_path: + raise FileNotFoundError + + touch_asset_info_by_id(session, asset_info_id=asset_info_id) + session.commit() + + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name + + +def upload_asset_from_temp_path( + spec: schemas_in.UploadAssetSpec, + *, + temp_path: str, + client_filename: str | None = None, + owner_id: str = "", + expected_asset_hash: str | None = None, +) -> schemas_out.AssetCreated: + """ + Create new asset or update existing asset from a temporary file path. 
+ """ + try: + # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment + import app.assets.hashing as hashing + digest = hashing.blake3_hash(temp_path) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): + raise ValueError("HASH_MISMATCH") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + info = create_asset_info_for_existing_asset( + session, + asset_hash=asset_hash, + name=display_name, + user_metadata=spec.user_metadata or {}, + tags=spec.tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=existing.hash, + size=int(existing.size_bytes) if existing.size_bytes is not None else None, + mime_type=existing.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + base_dir, subdirs = resolve_destination_from_tags(spec.tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or spec.name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + ensure_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + with create_session() as session: + result = ingest_fs_asset( + session, + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=spec.user_metadata or {}, + tags=spec.tags, + tag_origin="manual", + require_existing_tags=False, + ) + info_id = result["asset_info_id"] + if not info_id: + raise RuntimeError("failed to create asset metadata") + + pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + info, asset = pair + tag_names = get_asset_tags(session, asset_info_id=info.id) + created_result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + 
last_access_time=info.last_access_time, + created_new=result["asset_created"], + ) + session.commit() + + return created_result + + +def update_asset( + *, + asset_info_id: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetUpdated: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + info = update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + asset_info_row=info_row, + ) + + tag_names = get_asset_tags(session, asset_info_id=asset_info_id) + result = schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset.hash if info.asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + session.commit() + + return result + + +def set_asset_preview( + *, + asset_info_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> schemas_out.AssetDetail: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + result = schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + session.commit() + + return result + + +def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_id = info_row.asset_id if info_row else None + deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not deleted: + session.commit() + return False + + if not delete_content_if_orphan or not asset_id: + session.commit() + return True + + still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + states = list_cache_states_by_asset_id(session, asset_id=asset_id) + file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + return True + + +def create_asset_from_hash( + *, + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetCreated | None: + canonical = 
hash_str.strip().lower() + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + info = create_asset_info_for_existing_asset( + session, + asset_hash=canonical, + name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + session.commit() + + return result + + +def add_tags_to_asset( + *, + asset_info_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> schemas_out.TagsAdd: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + create_if_missing=True, + asset_info_row=info_row, + ) + session.commit() + return schemas_out.TagsAdd(**data) + + +def remove_tags_from_asset( + *, + asset_info_id: str, + tags: list[str], + owner_id: str = "", +) -> schemas_out.TagsRemove: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + data = remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + session.commit() + return schemas_out.TagsRemove(**data) + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> schemas_out.TagsList: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] + return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total) diff --git a/app/assets/scanner.py b/app/assets/scanner.py new file mode 100644 index 000000000..0172a5c2f --- /dev/null +++ b/app/assets/scanner.py @@ -0,0 +1,263 @@ +import contextlib +import time +import logging +import os +import sqlalchemy + +import folder_paths +from app.database.db import create_session, dependencies_available +from app.assets.helpers import ( + collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, + list_tree,prefixes_for_root, escape_like_prefix, + RootType +) +from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id +from app.assets.database.bulk_ops import seed_from_paths_batch +from app.assets.database.models import Asset, AssetCacheState, AssetInfo + + +def seed_assets(roots: 
tuple[RootType, ...], enable_logging: bool = False) -> None: + """ + Scan the given roots and seed the assets into the database. + """ + if not dependencies_available(): + if enable_logging: + logging.warning("Database dependencies not available, skipping assets scan") + return + t_start = time.perf_counter() + created = 0 + skipped_existing = 0 + orphans_pruned = 0 + paths: list[str] = [] + try: + existing_paths: set[str] = set() + for r in roots: + try: + survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True) + if survivors: + existing_paths.update(survivors) + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", r, e) + + try: + orphans_pruned = _prune_orphaned_assets(roots) + except Exception as e: + logging.exception("orphan pruning failed: %s", e) + + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_tree(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_tree(folder_paths.get_output_directory())) + + specs: list[dict] = [] + tag_pool: set[str] = set() + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped_existing += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=False) + except OSError: + continue + # skip empty files + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(abs_p), + } + ) + for t in tags: + tag_pool.add(t) + # if no file specs, nothing to do + if not specs: + return + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + + result = seed_from_paths_batch(sess, specs=specs, owner_id="") + created += result["inserted_infos"] + sess.commit() + finally: + if enable_logging: + logging.info( + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", + roots, + time.perf_counter() - t_start, + created, + skipped_existing, + orphans_pruned, + len(paths), + ) + + +def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: + """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" + all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] + if not all_prefixes: + return 0 + + def make_prefix_condition(prefix: str): + base = prefix if prefix.endswith(os.sep) else prefix + os.sep + escaped, esc = escape_like_prefix(base) + return AssetCacheState.file_path.like(escaped + "%", escape=esc) + + matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) + + orphan_subq = ( + sqlalchemy.select(Asset.id) + .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) + .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) + ).scalar_subquery() + + with create_session() as sess: + sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) + result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) + sess.commit() + return result.rowcount + + +def _fast_db_consistency_pass( + root: RootType, + *, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> set[str] | None: + 
"""Fast DB+FS pass for a root: + - Toggle needs_verify per state using fast check + - For hashed assets with at least one fast-ok state in this root: delete stale missing states + - For seed assets with all states missing: delete Asset and its AssetInfos + - Optionally add/remove 'missing' tags based on fast-ok in this root + - Optionally return surviving absolute paths + """ + prefixes = prefixes_for_root(root) + if not prefixes: + return set() if collect_existing_paths else None + + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_like_prefix(base) + conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc)) + + with create_session() as sess: + rows = ( + sess.execute( + sqlalchemy.select( + AssetCacheState.id, + AssetCacheState.file_path, + AssetCacheState.mtime_ns, + AssetCacheState.needs_verify, + AssetCacheState.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetCacheState.asset_id) + .where(sqlalchemy.or_(*conds)) + .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc()) + ) + ).all() + + by_asset: dict[str, dict] = {} + for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: + acc = by_asset.get(aid) + if acc is None: + acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} + by_asset[aid] = acc + + fast_ok = False + try: + exists = True + fast_ok = fast_asset_file_check( + mtime_db=mtime_db, + size_db=acc["size_db"], + stat_result=os.stat(fp, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except OSError: + exists = False + + acc["states"].append({ + "sid": sid, + "fp": fp, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": bool(needs_verify), + }) + + to_set_verify: list[int] = [] + to_clear_verify: list[int] = [] + stale_state_ids: list[int] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + states = acc["states"] + any_fast_ok = any(s["fast_ok"] for s in states) + all_missing = all(not s["exists"] for s in states) + + for s in states: + if not s["exists"]: + continue + if s["fast_ok"] and s["needs_verify"]: + to_clear_verify.append(s["sid"]) + if not s["fast_ok"] and not s["needs_verify"]: + to_set_verify.append(s["sid"]) + + if a_hash is None: + if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) + asset = sess.get(Asset, aid) + if asset: + sess.delete(asset) + else: + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + continue + + if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records + for s in states: + if not s["exists"]: + stale_state_ids.append(s["sid"]) + if update_missing_tags: + with contextlib.suppress(Exception): + remove_missing_tag_for_asset_id(sess, asset_id=aid) + elif update_missing_tags: + with contextlib.suppress(Exception): + add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") + + for s in states: + if s["exists"]: + survivors.add(os.path.abspath(s["fp"])) + + if stale_state_ids: + sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids))) + if to_set_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + .where(AssetCacheState.id.in_(to_set_verify)) + .values(needs_verify=True) + ) + if to_clear_verify: + sess.execute( + sqlalchemy.update(AssetCacheState) + 
.where(AssetCacheState.id.in_(to_clear_verify)) + .values(needs_verify=False) + ) + sess.commit() + return survivors if collect_existing_paths else None diff --git a/app/database/models.py b/app/database/models.py index 6facfb8f2..e7572677a 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,14 +1,21 @@ -from sqlalchemy.orm import declarative_base +from typing import Any +from datetime import datetime +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() +class Base(DeclarativeBase): + pass - -def to_dict(obj): +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out # TODO: Define models here diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py index dbe404541..6a8f586a4 100644 --- a/app/subgraph_manager.py +++ b/app/subgraph_manager.py @@ -10,6 +10,7 @@ import hashlib class Source: custom_node = "custom_node" + templates = "templates" class SubgraphEntry(TypedDict): source: str @@ -38,6 +39,18 @@ class CustomNodeSubgraphEntryInfo(TypedDict): class SubgraphManager: def __init__(self): self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None + + def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]: + """Create a subgraph entry from a file path. Expects normalized path (forward slashes).""" + entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": source, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": {"node_pack": node_pack}, + } + return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): with open(entry['path'], 'r') as f: @@ -60,53 +73,60 @@ class SubgraphManager: return entries async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): - # if not forced to reload and cached, return cache + """Load subgraphs from custom nodes.""" if not force_reload and self.cached_custom_node_subgraphs is not None: return self.cached_custom_node_subgraphs - # Load subgraphs from custom nodes - subfolder = "subgraphs" - subgraphs_dict: dict[SubgraphEntry] = {} + subgraphs_dict: dict[SubgraphEntry] = {} for folder in folder_paths.get_folder_paths("custom_nodes"): - pattern = os.path.join(folder, f"*/{subfolder}/*.json") - matched_files = glob.glob(pattern) - for file in matched_files: - # replace backslashes with forward slashes + pattern = os.path.join(folder, "*/subgraphs/*.json") + for file in glob.glob(pattern): file = file.replace('\\', '/') - info: CustomNodeSubgraphEntryInfo = { - "node_pack": "custom_nodes." + file.split('/')[-3] - } - source = Source.custom_node - # hash source + path to make sure id will be as unique as possible, but - # reproducible across backend reloads - id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() - entry: SubgraphEntry = { - "source": Source.custom_node, - "name": os.path.splitext(os.path.basename(file))[0], - "path": file, - "info": info, - } - subgraphs_dict[id] = entry + node_pack = "custom_nodes." 
+ file.split('/')[-3] + entry_id, entry = self._create_entry(file, Source.custom_node, node_pack) + subgraphs_dict[entry_id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict return subgraphs_dict - async def get_custom_node_subgraph(self, id: str, loadedModules): - subgraphs = await self.get_custom_node_subgraphs(loadedModules) - entry: SubgraphEntry = subgraphs.get(id, None) - if entry is not None and entry.get('data', None) is None: + async def get_blueprint_subgraphs(self, force_reload=False): + """Load subgraphs from the blueprints directory.""" + if not force_reload and self.cached_blueprint_subgraphs is not None: + return self.cached_blueprint_subgraphs + + subgraphs_dict: dict[SubgraphEntry] = {} + blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints') + + if os.path.exists(blueprints_dir): + for file in glob.glob(os.path.join(blueprints_dir, "*.json")): + file = file.replace('\\', '/') + entry_id, entry = self._create_entry(file, Source.templates, "comfyui") + subgraphs_dict[entry_id] = entry + + self.cached_blueprint_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_all_subgraphs(self, loadedModules, force_reload=False): + """Get all subgraphs from all sources (custom nodes and blueprints).""" + custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload) + blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload) + return {**custom_node_subgraphs, **blueprint_subgraphs} + + async def get_subgraph(self, id: str, loadedModules): + """Get a specific subgraph by ID from any source.""" + entry = (await self.get_all_subgraphs(loadedModules)).get(id) + if entry is not None and entry.get('data') is None: await self.load_entry_data(entry) return entry def add_routes(self, routes, loadedModules): @routes.get("/global_subgraphs") async def get_global_subgraphs(request): - subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) - # NOTE: we may want to include other sources of global subgraphs such as templates in the future; - # that's the reasoning for the current implementation + subgraphs_dict = await self.get_all_subgraphs(loadedModules) return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) @routes.get("/global_subgraphs/{id}") async def get_global_subgraph(request): id = request.match_info.get("id", None) - subgraph = await self.get_custom_node_subgraph(id, loadedModules) + subgraph = await self.get_subgraph(id, loadedModules) return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/blueprints/put_blueprints_here b/blueprints/put_blueprints_here new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..16998af94 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,11 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return 
self.model.state_dict() diff --git a/comfy/cli_args.py b/comfy/cli_args.py index dae9a895d..63daca861 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -159,6 +159,7 @@ class PerformanceFeature(enum.Enum): Fp8MatrixMultiplication = "fp8_matrix_mult" CublasOps = "cublas_ops" AutoTune = "autotune" + DynamicVRAM = "dynamic_vram" parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) @@ -231,6 +232,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() @@ -256,3 +258,6 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + return PerformanceFeature.DynamicVRAM in args.fast and not args.highvram and not args.gpu_only diff --git a/comfy/clip_model.py b/comfy/clip_model.py index e88872728..d7d3f994c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -1,6 +1,7 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +import math def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): image = image[:, :, :, :3] if image.shape[3] > 3 else image @@ -21,6 +22,39 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s image = torch.clip((255. * image), 0, 255).round() / 255.0 return (image - mean.view([3,1,1])) / std.view([3,1,1]) +def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5): + def scale_dim(size, scale): + scaled = math.ceil(size * scale / patch_size) * patch_size + return max(patch_size, int(scaled)) + + # Binary search for optimal scale + lo, hi = eps / 10, 100.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + h, w = scale_dim(oh, mid), scale_dim(ow, mid) + if (h // patch_size) * (w // patch_size) <= max_num_patches: + lo = mid + else: + hi = mid + + return scale_dim(oh, lo), scale_dim(ow, lo) + +def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True): + if size > 0: + return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop) + + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + + b, c, h, w = image.shape + h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches) + + image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True) + image = torch.clip((255. 
* image), 0, 255).round() / 255.0 + return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1]) + class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): super().__init__() @@ -175,6 +209,27 @@ class CLIPTextModel(torch.nn.Module): out = self.text_projection(x[2]) return (x[0], x[1], out, x[2]) +def siglip2_pos_embed(embed_weight, embeds, orig_shape): + embed_weight_len = round(embed_weight.shape[0] ** 0.5) + embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len) + embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True) + embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1) + return embeds + embed_weight + +class Siglip2Embeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None): + super().__init__() + self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device) + self.patch_size = patch_size + + def forward(self, pixel_values): + b, c, h, w = pixel_values.shape + img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c) + img = img.permute(0, 1, 3, 2, 4, 5) + img = img.reshape(b, img.shape[1] * img.shape[2], -1) + img = self.patch_embedding(img) + return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size)) class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None): @@ -218,8 +273,11 @@ class CLIPVision(torch.nn.Module): intermediate_activation = config_dict["hidden_act"] model_type = config_dict["model_type"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) - if model_type == "siglip_vision_model": + if model_type in ["siglip2_vision_model"]: + self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations) + else: + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) + if model_type in ["siglip_vision_model", "siglip2_vision_model"]: self.pre_layrnorm = lambda a: a self.output_layernorm = True else: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index d5fc53497..1691fca81 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -21,6 +21,7 @@ clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from br IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, } @@ -32,9 +33,10 @@ class 
ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_type = config.get("model_type", "clip_vision_model") - model_class = IMAGE_ENCODERS.get(model_type) - if model_type == "siglip_vision_model": + self.model_type = config.get("model_type", "clip_vision_model") + self.config = config.copy() + model_class = IMAGE_ENCODERS.get(self.model_type) + if self.model_type == "siglip_vision_model": self.return_all_hidden_states = True else: self.return_all_hidden_states = False @@ -45,22 +47,26 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + if self.model_type == "siglip2_vision_model": + pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float() + else: + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0] if self.return_all_hidden_states: all_hs = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = all_hs[:, -2] @@ -107,10 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - if embed_shape == 729: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif embed_shape == 1024: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") + patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape + if len(patch_embedding_shape) == 2: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json") + else: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = 
os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") diff --git a/comfy/clip_vision_siglip2_base_naflex.json b/comfy/clip_vision_siglip2_base_naflex.json new file mode 100644 index 000000000..6f6b99bd6 --- /dev/null +++ b/comfy/clip_vision_siglip2_base_naflex.json @@ -0,0 +1,14 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": -1, + "intermediate_size": 4304, + "model_type": "siglip2_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "num_patches": 256, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 071b98332..0194b7d70 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -236,6 +236,8 @@ class ComfyNodeABC(ABC): """Flags a node as experimental, informing users that it may change or not work as expected.""" DEPRECATED: bool """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + DEV_ONLY: bool + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" API_NODE: Optional[bool] """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..9e1e704e0 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling diff --git a/comfy/float.py b/comfy/float.py index 521316fd2..88c47cd80 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,147 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? +def stochastic_float_to_fp4_e2m1(x, generator): + orig_shape = x.shape + sign = torch.signbit(x).to(torch.uint8) + + exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x = x.abs() + exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x * 2.0), + out=x + ).round().to(torch.uint8) + del x + + exp = exp.to(torch.uint8) + + fp4 = (sign << 3) | (exp << 1) | mantissa + del sign, exp, mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(orig_shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. 
+ See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + + x = x.view(orig_shape).nan_to_num() + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + return data_lp, scaled_block_scales_fp8 + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator) + return x, to_blocked(blocked_scaled, flatten=False) + + +def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. 
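+ # The slicing loop further down quantizes roughly block_size elements of rows at a time; along the last dim the outputs pack two FP4 values per uint8 (width halved) and one FP8 scale per 16-element block.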
+ orig_shape = x.shape + + orig_shape = list(orig_shape) + + output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device) + output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + num_slices = max(1, (x.numel() / block_size)) + slice_size = max(1, (round(x.shape[0] / num_slices))) + + for i in range(0, x.shape[0], slice_size): + fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator) + output_fp4[i:i + slice_size].copy_(fp4) + output_block[i:i + slice_size].copy_(block) + + return output_fp4, to_blocked(output_block, flatten=False) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 0949dee44..c0c51d51a 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1,11 +1,12 @@ import math +import time from functools import partial from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import trange as trange_, tqdm from . import utils from . import deis @@ -13,6 +14,36 @@ from . import sa_solver import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management + + +def trange(*args, **kwargs): + if comfy.memory_management.aimdo_allocator is None: + return trange_(*args, **kwargs) + + pbar = trange_(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! ") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. 
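+ # concretely: start_t = i1_time - (t2 - i1_time), so the first iteration is counted as if it took as long as the second and the model-load overhead drops out of the final it/s estimate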
+ pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e..4b3a3798c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -8,6 +8,7 @@ class LatentFormat: latent_rgb_factors_bias = None latent_rgb_factors_reshape = None taesd_decoder_name = None + spacial_downscale_ratio = 8 def process_in(self, latent): return latent * self.scale_factor @@ -80,6 +81,7 @@ class SD_X4(LatentFormat): class SC_Prior(LatentFormat): latent_channels = 16 + spacial_downscale_ratio = 42 def __init__(self): self.scale_factor = 1.0 self.latent_rgb_factors = [ @@ -102,6 +104,7 @@ class SC_Prior(LatentFormat): ] class SC_B(LatentFormat): + spacial_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ @@ -181,6 +184,7 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + spacial_downscale_ratio = 16 def __init__(self): self.latent_rgb_factors =[ @@ -272,6 +276,7 @@ class Mochi(LatentFormat): class LTXV(LatentFormat): latent_channels = 128 latent_dimensions = 3 + spacial_downscale_ratio = 32 def __init__(self): self.latent_rgb_factors = [ @@ -407,6 +412,11 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 @@ -510,6 +520,7 @@ class Wan21(LatentFormat): class Wan22(Wan21): latent_channels = 48 latent_dimensions = 3 + spacial_downscale_ratio = 16 latent_rgb_factors = [ [ 0.0119, 0.0103, 0.0046], @@ -587,6 +598,7 @@ class Wan22(Wan21): class HunyuanImage21(LatentFormat): latent_channels = 64 latent_dimensions = 2 + spacial_downscale_ratio = 32 scale_factor = 0.75289 latent_rgb_factors = [ @@ -720,6 +732,7 @@ class HunyuanVideo15(LatentFormat): latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] latent_channels = 32 latent_dimensions = 3 + spacial_downscale_ratio = 16 scale_factor = 1.03682 taesd_decoder_name = "lighttaehy1_5" @@ -744,6 +757,7 @@ class ACEAudio(LatentFormat): class ChromaRadiance(LatentFormat): latent_channels = 3 + spacial_downscale_ratio = 1 def __init__(self): self.latent_rgb_factors = [ diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000..2e6ed58fa --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,202 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + 
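+ # inv_freq broadcast to [batch, head_dim//2, 1]; the matmul with the [batch, 1, seq_len] position ids below gives per-position rotation angles (cos/sin are computed in float32)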
position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = 
operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + x = self.in_proj(self.embed(target_input_ids)) + context = source_hidden_states + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def 
preprocess_text_embeds(self, text_embeds, text_ids): + if text_ids is not None: + return self.llm_adapter(text_embeds, text_ids) + else: + return text_embeds diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 07a4fc79f..c270e6333 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -13,6 +13,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention +import comfy.ldm.common_dit def apply_rotary_pos_emb( t: torch.Tensor, @@ -835,6 +836,8 @@ class MiniTrainDIT(nn.Module): padding_mask: Optional[torch.Tensor] = None, **kwargs, ): + orig_shape = list(x.shape) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps crossattn_emb = context @@ -882,5 +885,5 @@ class MiniTrainDIT(nn.Module): ) x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) - x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] return x_B_C_Tt_Hp_Wp diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..f9597de5b 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,20 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): - x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - x_out = freqs_cis[..., 0] * x_[..., 0] - x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) +try: + import comfy.quant_ops + apply_rope = comfy.quant_ops.ck.apply_rope + apply_rope1 = comfy.quant_ops.ck.apply_rope1 +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + def apply_rope1(x: Tensor, freqs_cis: Tensor): + x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) - return x_out.reshape(*x.shape).type_as(x) + x_out = freqs_cis[..., 0] * x_[..., 0] + x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + return x_out.reshape(*x.shape).type_as(x) + + def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index d9e76922f..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,8 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from 
comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management -import model_patcher +import comfy.model_management +import comfy.model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): @@ -103,20 +103,20 @@ UPSAMPLERS = { class HunyuanVideo15SRModel(): def __init__(self, model_type, config): - self.load_device = model_management.vae_device() - offload_device = model_management.vae_offload_device() - self.dtype = model_management.vae_dtype(self.load_device) + self.load_device = comfy.model_management.vae_device() + offload_device = comfy.model_management.vae_offload_device() + self.dtype = comfy.model_management.vae_dtype(self.load_device) self.model_class = UPSAMPLERS.get(model_type) self.model = self.model_class(**config).eval() - self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() def resample_latent(self, latent): - model_management.load_model_gpu(self.patcher) + comfy.model_management.load_model_gpu(self.patcher) return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000..2c6954ecd --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,871 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + LTXVModel, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +import comfy.ldm.common_dit + +class CompressedTimestep: + """Store video timestep embeddings in compressed form using per-frame indexing.""" + __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') + + def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + """ + tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame + patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression + """ + self.batch_size, num_tokens, self.feature_dim = tensor.shape + + # Check if compression is valid (num_tokens must be divisible by patches_per_frame) + if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.patches_per_frame = patches_per_frame + self.num_frames = num_tokens // patches_per_frame + + # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame + # All patches in a frame are identical, so we only keep the first one + reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) + self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + else: + # Not divisible or too small - store directly without compression + self.patches_per_frame = 1 + self.num_frames = num_tokens + self.data = tensor + + def expand(self): + """Expand back to original tensor.""" + if self.patches_per_frame == 1: + return self.data + + # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim] + expanded = 
self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim) + return expanded.reshape(self.batch_size, -1, self.feature_dim) + + def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)): + """Compute ada values on compressed per-frame data, then expand spatially.""" + num_ada_params = scale_shift_table.shape[0] + + # No compression - compute directly + if self.patches_per_frame == 1: + num_tokens = self.data.shape[1] + dim_per_param = self.feature_dim // num_ada_params + reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype) + ada_values = (table_values + reshaped).unbind(dim=2) + return ada_values + + # Compressed: compute on per-frame data then expand spatially + # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param] + frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to( + device=self.data.device, dtype=self.data.dtype + ) + frame_ada = (table_values + frame_reshaped).unbind(dim=2) + + # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim] + return tuple( + frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1) + .reshape(batch_size, -1, frame_val.shape[-1]) + for frame_val in frame_ada + ) + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + 
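+ # AdaLN-style modulation tables: rows 0-2 hold shift/scale/gate for self-attention, rows 3-5 for the feed-forward; the 5-row a2v/v2a tables hold four cross-attention scale/shift rows plus a gate (see the get_ada_values slices in forward)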
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(6, a_dim, device=device, dtype=dtype) + ) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + if isinstance(timestep, CompressedTimestep): + return timestep.expand_for_computation(scale_shift_table, batch_size, indices) + + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], + batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + return (*scale_shift_ada_values, *gate_ada_values) + + def forward( + self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, + v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + # video + if run_vx: + # video self-attention + vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + del vshift_msa, vscale_msa + attn1_out = self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) + del norm_vx + # video cross-attention + vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] + vx.addcmul_(attn1_out, vgate_msa) + del vgate_msa, attn1_out + vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + + # audio + if run_ax: + # audio self-attention + ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2))) + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + del ashift_msa, ascale_msa + attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + del norm_ax + # audio cross-attention + agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] + ax.addcmul_(attn1_out, agate_msa) + del agate_msa, attn1_out + ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), 
context=a_context, mask=attention_mask, transformer_options=transformer_options)) + + # video - audio cross attention. + if run_a2v or run_v2a: + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + # audio to video cross attention + if run_a2v: + scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2] + scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2] + + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v + + a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options) + del vx_scaled, ax_scaled + + gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0] + vx.addcmul_(a2v_out, gate_out_a2v) + del gate_out_a2v, a2v_out + + # video to audio cross attention + if run_v2a: + scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4] + scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4] + + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a + + v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options) + del ax_scaled, vx_scaled + + gate_out_v2a = self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0] + ax.addcmul_(v2a_out, gate_out_v2a) + del gate_out_v2a, v2a_out + + del vx_norm3, ax_norm3 + + # video feedforward + if run_vx: + vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5)) + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + del vshift_mlp, vscale_mlp + + ff_out = self.ff(vx_scaled) + del vx_scaled + + vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0] + vx.addcmul_(ff_out, vgate_mlp) + del vgate_mlp, ff_out + + # audio feedforward + if run_ax: + ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5)) + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + del ashift_mlp, ascale_mlp + + ff_out = self.audio_ff(ax_scaled) + del ax_scaled + + agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0] + ax.addcmul_(ff_out, agate_mlp) + del agate_mlp, ff_out + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" 
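# Illustrative sketch, not part of the patch: every sub-layer in BasicAVTransformerBlock
# above follows the same modulation pattern -- shift/scale the RMS-normalized stream,
# run the sub-layer, then accumulate its output through a learned gate with addcmul_.
# Names and shapes here are hypothetical.
import torch

def rms_norm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

def gated_sublayer(x, sublayer, shift, scale, gate):
    h = rms_norm(x) * (1 + scale) + shift    # AdaLN-style modulation from the timestep
    x.addcmul_(sublayer(h), gate)            # x += sublayer(h) * gate, in place
    return x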
+ + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + 
device=device, + operations=self.operations, + ) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + + has_spatial_mask = False + if denoise_mask is not None: + # check if any frame has spatial variation (inpainting) + for frame_idx in range(denoise_mask.shape[2]): + frame_mask = denoise_mask[0, 0, frame_idx] + if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max(): + has_spatial_mask = True + break + + [vx, v_pixel_coords, additional_args] = super()._process_input( + 
vx, keyframe_idxs, denoise_mask, **kwargs + ) + additional_args["has_spatial_mask"] = has_spatial_mask + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep_scaled = timestep * self.timestep_scale_multiplier + + v_timestep, v_embedded_timestep = self.adaln_single( + timestep_scaled.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] + # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width + orig_shape = kwargs.get("orig_shape") + has_spatial_mask = kwargs.get("has_spatial_mask", None) + v_patches_per_frame = None + if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: + # orig_shape[3] = height, orig_shape[4] = width (in latent space) + v_patches_per_frame = orig_shape[3] * orig_shape[4] + + # Reshape to [batch_size, num_tokens, dim] and compress for storage + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + if a_timestep is not None: + a_timestep_scaled = a_timestep * self.timestep_scale_multiplier + a_timestep_flat = a_timestep_scaled.flatten() + timestep_flat = timestep_scaled.flatten() + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + # Cross-attention timesteps - compress these too + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + a_timestep_flat, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + timestep_flat, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + timestep_flat * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + a_timestep_flat * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Compress cross-attention timesteps (only video side, audio is too small to benefit) + # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), + CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible 
+ av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), + ] + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep_flat, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + # Audio timesteps + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + else: + a_timestep = timestep_scaled + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + + return [v_timestep, a_timestep, cross_av_timestep_ss], [ + v_embedded_timestep, + a_embedded_timestep, + ] + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + v_context, a_context = torch.split( + context, int(context.shape[-1] / 2), len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.audio_caption_projection is not None: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, ax.shape[-1]) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. 
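# Illustrative sketch, hypothetical values: with use_middle_indices_grid=True the rotary
# frequencies for the audio/video cross-attention are taken from the midpoint of each
# token's temporal span, so tokens covering the same time window get matching phases.
import torch

starts = torch.tensor([[0.0, 8.0, 16.0]])   # per-token start positions (latent frames)
ends   = torch.tensor([[8.0, 16.0, 24.0]])  # per-token end positions
middle = (starts + ends) / 2.0              # tensor([[ 4., 12., 20.]]) used for RoPE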
+ max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + ) + + return [vx, ax] + + def 
_process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + + # Expand compressed video timestep if needed + if isinstance(v_embedded_timestep, CompressedTimestep): + v_embedded_timestep = v_embedded_timestep.expand() + + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. + + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000..06f5ada89 --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,305 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. 
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. + """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. 
Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + * 2.0 + - 1.0 + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. 
x 3 = 6 + start = 1 + end = theta + + if spacing == "exp": + indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem)) + indices = indices.to(dtype=dtype, device=device) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace( + start, end, dim // n_elem, device=device, dtype=dtype + ) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // n_elem, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + return indices + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dim = self.inner_dim + n_elem = 2 # 2 because of cos and sin + freqs = self.precompute_freqs(indices_grid, spacing) + if self.split_rope: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis( + freqs, pad_size, self.num_attention_heads + ) + else: + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + + if self.num_learnable_registers: + num_registers_duplications = math.ceil( + max(1024, hidden_states.shape[1]) / self.num_learnable_registers + ) + learnable_registers = torch.tile( + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) + ) + + hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) + + if attention_mask is not None: + attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device) + + indices_grid = torch.arange( + hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device + ) + indices_grid = indices_grid[None, None, :] + freqs_cis = self.precompute_freqs_cis(indices_grid) + + # 2. Blocks + for block_idx, block in enumerate(self.transformer_1d_blocks): + hidden_states = block( + hidden_states, attention_mask=attention_mask, pe=freqs_cis + ) + + # 3. 
Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000..78ed7653f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). 
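# Illustrative sketch, hypothetical sizes: the rational resampler realises a 1.5x spatial
# scale as "up 3, down 2" -- PixelShuffle grows H and W by num, then the fixed 5x5
# binomial blur with stride den downsamples without aliasing.
num, den = 3, 2                          # _rational_for_scale(1.5)
h = w = 32
h_up, w_up = h * num, w * num            # 96 x 96 after PixelShuffleND(2, (num, num))
h_out, w_out = h_up // den, w_up // den  # 48 x 48 after BlurDownsample(stride=den)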
+ """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either 
spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f..d61e19d6e 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,47 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +73,7 @@ def 
get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module): post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. 
""" - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module): embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -222,23 +282,68 @@ class GELU_approx(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if 
split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +359,11 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,8 +373,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -277,14 +384,34 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn1 = 
CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) @@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module): return x def get_fractional_positions(indices_grid, max_pos): + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' fractional_positions = torch.stack( - [ - indices_grid[:, i] / max_pos[i] - for i in range(3) - ], - dim=-1, + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + axis=-1, ) return fractional_positions -def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 - device = indices_grid.device +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None): + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + _log_base(start, theta), + _log_base(end, theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + +def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device): + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + device=device, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + +def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid): + if use_middle_indices_grid: + assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2) + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) - indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 + indices = indices.to(device=fractional_positions.device) - # Compute frequencies and apply cos/sin - freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) - cos_vals = freqs.cos().repeat_interleave(2, dim=-1) - sin_vals = freqs.sin().repeat_interleave(2, dim=-1) + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs - # Pad if dim is not divisible by 6 - if dim % 6 != 0: - padding_size = dim % 6 - 
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) - sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) +def interleaved_freqs_cis(freqs, pad_size): + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq - # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] +def split_freqs_cis(freqs, pad_size, num_attention_heads): + cos_freq = freqs.cos() + sin_freq = freqs.sin() - # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension - freqs_cis = torch.stack([ - torch.stack([cos_vals, -sin_vals], dim=-1), - torch.stack([sin_vals, cos_vals], dim=-1) - ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - return freqs_cis + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + # Reshape freqs to be compatible with multi-head attention + B , T, half_HD = cos_freq.shape -class LTXVModel(torch.nn.Module): - def __init__(self, - in_channels=128, - cross_attention_dim=2048, - attention_head_dim=64, - num_attention_heads=32, + cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) + sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) - caption_channels=4096, - num_layers=28, + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq +class LTXBaseModel(torch.nn.Module, ABC): + """ + Abstract base class for LTX models (Lightricks Transformer models). - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. 
+ """ + + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels + + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device + ) self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, ) + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. 
Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe + + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask + + def forward( + self, x, timestep, context, attention_mask, 
frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x + + +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + 
vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + # No additional components needed for LTXV beyond base class + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, ) - for d in range(num_layers) + for _ in range(self.num_layers) ] ) + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) - - self.patchifier = SymmetricPatchifier(1) - - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) - - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - - orig_shape = list(x.shape) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module): causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] 
- fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + return x, pixel_coords, additional_args + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) @@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module): transformer_options=transformer_options, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9f..8f9a41186 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. 
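A small worked example of the causal fix, assuming scale_factors=(8, 32, 32) (the LTXV defaults used elsewhere in this diff); the snippet is illustrative only:

    import torch
    # temporal latent indices map to pixel frames as t*8 + 1 - 8, clamped at 0,
    # so frame 0 stays pinned to pixel frame 0 while later frames keep an 8-frame stride
    latent_t = torch.tensor([0, 1, 2, 3])
    pixel_t = (latent_t * 8 + 1 - 8).clamp(min=0)   # tensor([0, 1, 9, 17])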
""" + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) return pixel_coords class Patchifier(ABC): - def __init__(self, patch_size: int): + def __init__(self, patch_size: int, start_end: bool=False): super().__init__() self._patch_size = (1, patch_size, patch_size) + self.start_end = start_end @abstractmethod def patchify( @@ -71,11 +74,23 @@ class Patchifier(ABC): torch.arange(0, latent_width, self._patch_size[2], device=device), indexing="ij", ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size + latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0) + delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None] + latent_sample_coords_end = latent_sample_coords_start + delta + + latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_start = rearrange( + latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size ) + if self.start_end: + latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_end = rearrange( + latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size + ) + + latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1) + else: + latent_coords = latent_sample_coords_start return latent_coords @@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier): q=self._patch_size[2], ) return latents + + +class AudioPatchifier(Patchifier): + def __init__(self, patch_size: int, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + is_causal=True, + start_end=False, + shift = 0 + ): + super().__init__(patch_size, start_end=start_end) + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + + def copy_with_shift(self, shift): + return AudioPatchifier( + self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor, + self.is_causal, self.start_end, shift + ) + + def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device): + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + if self.is_causal: + audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0) + return audio_mel_frame * self.hop_length / self.sample_rate + + + def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # audio_latents: (batch, channels, time, freq) + b, _, t, _ = audio_latents.shape + audio_latents = rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + 
self.shift, torch.float32, audio_latents.device) + audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + if self.start_end: + audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device) + audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1) + else: + audio_latents_timings = audio_latents_start_timings + return audio_latents, audio_latents_timings + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: + # audio_latents: (batch, time, freq * channels) + audio_latents = rearrange( + audio_latents, "b t (c f) -> b c t f", c=channels, f=freq + ) + return audio_latents diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py new file mode 100644 index 000000000..55a074661 --- /dev/null +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -0,0 +1,279 @@ +import json +from dataclasses import dataclass +import math +import torch +import torchaudio + +import comfy.model_management +import comfy.model_patcher +import comfy.utils as utils +from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( + CausalityAxis, + CausalAudioAutoencoder, +) +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +@dataclass(frozen=True) +class AudioVAEComponentConfig: + """Container for model component configuration extracted from metadata.""" + + autoencoder: dict + vocoder: dict + + @classmethod + def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": + assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE" + + raw_config = metadata["config"] + if isinstance(raw_config, str): + parsed_config = json.loads(raw_config) + else: + parsed_config = raw_config + + audio_config = parsed_config.get("audio_vae") + vocoder_config = parsed_config.get("vocoder") + + assert audio_config is not None, "Audio VAE config is required for audio VAE" + assert vocoder_config is not None, "Vocoder config is required for audio VAE" + + return cls(autoencoder=audio_config, vocoder=vocoder_config) + + +class ModelDeviceManager: + """Manages device placement and GPU residency for the composed model.""" + + def __init__(self, module: torch.nn.Module): + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) + + def ensure_model_loaded(self) -> None: + comfy.model_management.free_memory( + self.patcher.model_size(), + self.patcher.load_device, + ) + comfy.model_management.load_model_gpu(self.patcher) + + def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(self.patcher.load_device) + + @property + def load_device(self): + return self.patcher.load_device + + +class AudioLatentNormalizer: + """Applies per-channel statistics in patch space and restores original layout.""" + + def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module): + self.patchifier = patchfier + self.statistics = statistics_processor + + def normalize(self, latents: torch.Tensor) -> 
torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + normalized = self.statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, channels=channels, freq=freq) + + def denormalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. 
+ self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + if waveform.shape[1] == 1: + waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:]) + else: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py 
b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000..f12b9bb53 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,909 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. 
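A rough illustration of the asymmetric padding this layer applies, assuming a 3x3 kernel and causality along the width axis (snippet not taken from the diff):

    import torch
    import torch.nn.functional as F
    x = torch.arange(5.0).view(1, 1, 1, 5)   # (N, C, H, W), with W acting as the causal/time axis
    # padding tuple is (left, right, top, bottom) = (2, 0, 1, 1) for kernel_size=3
    x_causal = F.pad(x, (2, 0, 1, 1))
    # a 3x3 convolution over x_causal at column t now only sees columns t-2..t of x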
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. 
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type 
'{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. 
+ + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, 
norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis) + + def _adjust_output_shape(self, decoded_output, target_shape): + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: Target shape tuple (batch, channels, time, frequency) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + _, target_channels, target_time, target_freq = target_shape + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def get_config(self): + return { + "ch": self.ch, + "out_ch": self.out_ch, + "ch_mult": self.ch_mult, + "num_res_blocks": self.num_res_blocks, + "in_channels": self.in_channels, + "resolution": self.resolution, + "z_channels": self.z_channels, + } + + def forward(self, latent_features, target_shape=None): + """ + Decode latent features back to audio spectrograms. 
+ + Args: + latent_features: Encoded latent representation of shape (batch, channels, height, width) + target_shape: Target output shape (batch, channels, time, frequency). + Required here; the decoded output is cropped/padded to match this shape + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder" + + # Transform latent features to decoder's internal feature dimension + hidden_features = self.conv_in(latent_features) + + # Middle processing + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + # Upsampling + # Progressively increase spatial resolution from lowest to highest + for resolution_level in reversed(range(self.num_resolutions)): + # Apply residual blocks at current resolution level + for block_index in range(self.num_res_blocks + 1): + hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None) + + if len(self.up[resolution_level].attn) > 0: + hidden_features = self.up[resolution_level].attn[block_index](hidden_features) + + if resolution_level != 0: + hidden_features = self.up[resolution_level].upsample(hidden_features) + + # Output + if self.give_pre_end: + # Return intermediate features before final processing (for debugging/analysis) + decoded_output = hidden_features + else: + # Standard output path: normalize, activate, and convert to output channels + # Final normalization layer + hidden_features = self.norm_out(hidden_features) + + # Apply SiLU (Swish) activation function + hidden_features = self.non_linearity(hidden_features) + + # Final convolution to map to output channels (typically 2 for stereo audio) + decoded_output = self.conv_out(hidden_features) + + # Optional tanh activation to bound output values to [-1, 1] range + if self.tanh_out: + decoded_output = torch.tanh(decoded_output) + + # Adjust shape for audio data + if target_shape is not None: + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + + def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class CausalAudioAutoencoder(nn.Module): + def __init__(self, config=None): + super().__init__() + + if config is None: + config = self._guess_config() + + # Extract encoder and decoder configs from the new format + model_config = config.get("model", {}).get("params", {}) + variables_config = config.get("variables", {}) + + self.sampling_rate = variables_config.get( + "sampling_rate", + model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + ) + encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) + decoder_config = model_config.get("decoder", encoder_config) + + # Load mel spectrogram parameters + self.mel_bins = encoder_config.get("mel_bins", 64) + self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + + #
Store causality configuration at VAE level (not just in encoder internals) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def _guess_config(self): + encoder_config = { + # Required parameters - based on ltx-video-av-1679000 model metadata + "ch": 128, + "out_ch": 8, + "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] + "num_res_blocks": 2, + "attn_resolutions": [], # Based on metadata: empty list, no attention + "dropout": 0.0, + "resamp_with_conv": True, + "in_channels": 2, # stereo + "resolution": 256, + "z_channels": 8, + "double_z": True, + "attn_type": "vanilla", + "mid_block_add_attention": False, # Based on metadata: false + "norm_type": "pixel", + "causality_axis": "height", # Based on metadata + "mel_bins": 64, # Based on metadata: mel_bins = 64 + } + + decoder_config = { + # Inherits encoder config, can override specific params + **encoder_config, + "out_ch": 2, # Stereo audio output (2 channels) + "give_pre_end": False, + "tanh_out": False, + } + + config = { + "_class_name": "CausalAudioAutoencoder", + "sampling_rate": 16000, + "model": { + "params": { + "encoder": encoder_config, + "decoder": decoder_config, + } + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 70d612e86..b8341edbc 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -1,11 +1,11 @@ from typing import Tuple, Union +import threading import torch import torch.nn as nn import comfy.ops ops = comfy.ops.disable_weight_init - class CausalConv3d(nn.Module): def __init__( self, @@ -42,23 +42,34 @@ class CausalConv3d(nn.Module): padding_mode=spatial_padding_mode, groups=groups, ) + self.temporal_cache_state={} def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x + tid = threading.get_ident() + + cached, is_end = self.temporal_cache_state.get(tid, (None, False)) + if cached is None: + padding_length = self.time_kernel_size - 1 + if not causal: + padding_length = padding_length // 2 + if x.shape[2] == 0: + return x + cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1)) + pieces = [ cached, x ] + if is_end and not causal: + pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + + needs_caching = not is_end + if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + 
needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + x = torch.cat(pieces, dim=2) + + if needs_caching: + self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + + return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] @property def weight(self): diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 75ed069ad..cbfdf412d 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,4 +1,5 @@ from __future__ import annotations +import threading import torch from torch import nn from functools import partial @@ -6,12 +7,35 @@ import math from einops import rearrange from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd +from .causal_conv3d import CausalConv3d from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops +from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def mark_conv3d_ended(module): + tid = threading.get_ident() + for _, m in module.named_modules(): + if isinstance(m, CausalConv3d): + current = m.temporal_cache_state.get(tid, (None, False)) + m.temporal_cache_state[tid] = (current[0], True) + +def split2(tensor, split_point, dim=2): + return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) + +def add_exchange_cache(dest, cache_in, new_input, dim=2): + if dest is not None: + if cache_in is not None: + cache_to_dest = min(dest.shape[dim], cache_in.shape[dim]) + lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim) + lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim) + lead_in_dest.add_(lead_in_source) + body, new_input = split2(new_input, dest.shape[dim], dim) + dest.add_(body) + return torch_cat_if_needed([cache_in, new_input], dim=dim) + class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. @@ -205,7 +229,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: r"""The forward method of the `Encoder` class.""" sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) @@ -254,6 +278,22 @@ class Encoder(nn.Module): return sample + def forward(self, *args, **kwargs): + #No encoder support so just flag the end so it doesnt use the cache. + mark_conv3d_ended(self) + try: + return self.forward_orig(*args, **kwargs) + finally: + tid = threading.get_ident() + for _, module in self.named_modules(): + # ComfyUI doesn't thread this kind of stuff today, but just in case + # we key on the thread to make it thread safe. 
+ tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + + +MAX_CHUNK_SIZE=(128 * 1024 ** 2) class Decoder(nn.Module): r""" @@ -341,18 +381,6 @@ class Decoder(nn.Module): timestep_conditioning=timestep_conditioning, spatial_padding_mode=spatial_padding_mode, ) - elif block_name == "attn_res_x": - block = UNetMidBlock3D( - dims=dims, - in_channels=input_channel, - num_layers=block_params["num_layers"], - resnet_groups=norm_num_groups, - norm_layer=norm_layer, - inject_noise=block_params.get("inject_noise", False), - timestep_conditioning=timestep_conditioning, - attention_head_dim=block_params["attention_head_dim"], - spatial_padding_mode=spatial_padding_mode, - ) elif block_name == "res_x_y": output_channel = output_channel // block_params.get("multiplier", 2) block = ResnetBlock3D( @@ -428,8 +456,9 @@ class Decoder(nn.Module): ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - def forward( + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, @@ -437,6 +466,7 @@ class Decoder(nn.Module): r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] + mark_conv3d_ended(self.conv_in) sample = self.conv_in(sample, causal=self.causal) checkpoint_fn = ( @@ -445,24 +475,12 @@ class Decoder(nn.Module): else lambda x: x ) - scaled_timestep = None + timestep_shift_scale = None if self.timestep_conditioning: assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) - - for up_block in self.up_blocks: - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) - - sample = self.conv_norm_out(sample) - - if self.timestep_conditioning: embedded_timestep = self.last_time_embedder( timestep=scaled_timestep.flatten(), resolution=None, @@ -483,16 +501,62 @@ class Decoder(nn.Module): embedded_timestep.shape[-2], embedded_timestep.shape[-1], ) - shift, scale = ada_values.unbind(dim=1) - sample = sample * (1 + scale) + shift + timestep_shift_scale = ada_values.unbind(dim=1) - sample = self.conv_act(sample) - sample = self.conv_out(sample, causal=self.causal) + output = [] + + def run_up(idx, sample, ended): + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + output.append(sample) + return + + up_block = self.up_blocks[idx] + if (ended): + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, 
sample1 in enumerate(samples): + run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) + + run_up(0, sample, True) + sample = torch.cat(output, dim=2) sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) return sample + def forward(self, *args, **kwargs): + try: + return self.forward_orig(*args, **kwargs) + finally: + for _, module in self.named_modules(): + #ComfyUI doesn't thread this kind of stuff today, but just incase + #we key on the thread to make it thread safe. + tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) + class UNetMidBlock3D(nn.Module): """ @@ -663,8 +727,22 @@ class DepthToSpaceUpsample(nn.Module): ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor + self.temporal_cache_state = {} def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): + tid = threading.get_ident() + cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True)) + y = self.conv(x, causal=causal) + y = rearrange( + y, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv: + y = y[:, :, 1:, :, :] + drop_first_conv = False if self.residual: # Reshape and duplicate the input to match the output shape x_in = rearrange( @@ -676,21 +754,20 @@ class DepthToSpaceUpsample(nn.Module): ) num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: + if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res: x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x + drop_first_res = False + + if y.shape[2] == 0: + y = None + + cached = add_exchange_cache(y, cached, x_in, dim=2) + self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res) + + else: + self.temporal_cache_state[tid] = (None, drop_first_conv, False) + + return y class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: @@ -807,6 +884,8 @@ class ResnetBlock3D(nn.Module): torch.randn(4, in_channels) / in_channels**0.5 ) + self.temporal_cache_state={} + def _feed_spatial_noise( self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor ) -> torch.FloatTensor: @@ -880,9 +959,12 @@ class ResnetBlock3D(nn.Module): input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + tid = threading.get_ident() + cached = self.temporal_cache_state.get(tid, None) + cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2) + self.temporal_cache_state[tid] = cached - return output_tensor + return hidden_states def patchify(x, patch_size_hw, patch_size_t=1): diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000..b1f15f2c5 --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import numpy as np + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return 
int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. 
+ + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + resblock = config.get("resblock", "1") + + self.output_sample_rate = config.get("output_sample_rate") + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + out_channels = 2 if stereo else 1 + self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [6, 5, 2, 2, 2], + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. 
Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index afbab2ac7..77d1abc97 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -13,10 +13,53 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension +import comfy.utils -def modulate(x, scale): - return x * (1 + scale.unsqueeze(1)) +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + +def modulate(x, scale, timestep_zero_index=None): + if timestep_zero_index is None: + return x * (1 + scale.unsqueeze(1)) + else: + scale = (1 + scale.unsqueeze(1)) + actual_batch = scale.size(0) // 2 + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= scale[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= scale[:actual_batch] + return x + + +def apply_gate(gate, x, timestep_zero_index=None): + if timestep_zero_index is None: + return gate * x + else: + actual_batch = gate.size(0) // 2 + + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= gate[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= gate[:actual_batch] + return x ############################################################################# # Core NextDiT Model # @@ -258,6 +301,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + timestep_zero_index=None, transformer_options={}, ): """ @@ -276,18 +320,18 @@ class JointTransformerBlock(nn.Module): assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2( clamp_fp16(self.attention( - modulate(self.attention_norm1(x), scale_msa), + modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index), x_mask, freqs_cis, transformer_options=transformer_options, - )) + ))), timestep_zero_index=timestep_zero_index ) - x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2( clamp_fp16(self.feed_forward( - modulate(self.ffn_norm1(x), scale_mlp), - )) + modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index), + ))), timestep_zero_index=timestep_zero_index ) else: 
assert adaln_input is None @@ -345,13 +389,37 @@ class FinalLayer(nn.Module): ), ) - def forward(self, x, c): + def forward(self, x, c, timestep_zero_index=None): scale = self.adaLN_modulation(c) - x = modulate(self.norm_final(x), scale) + x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index) x = self.linear(x) return x +def pad_zimage(feats, pad_token, pad_tokens_multiple): + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra + + +def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}): + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() + return x_pos_ids + + class NextDiT(nn.Module): """ Diffusion model with a Transformer backbone. @@ -378,10 +446,12 @@ class NextDiT(nn.Module): time_scale=1.0, pad_tokens_multiple=None, clip_text_dim=None, + siglip_feat_dim=None, image_model=None, device=None, dtype=None, operations=None, + **kwargs, ) -> None: super().__init__() self.dtype = dtype @@ -491,6 +561,41 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) + + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + siglip_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.siglip_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + # This norm final is in the lumina 2.0 code but isn't actually used for anything. 
# self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) @@ -531,70 +636,168 @@ class NextDiT(nn.Module): imgs = torch.stack(imgs, dim=0) return imgs - def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} - ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - orig_x = x - - if self.pad_tokens_multiple is not None: - pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None): + if cap_feats is not None: + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + else: + cap_feats_len = 0 + cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) - cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + embeds = (cap_feats,) + freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + + def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}): + bsz = 1 + pH = pW = self.patch_size + device = x.device + embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) + + if (not omni) or self.siglip_embedder is None: + cap_feats_len = embeds[0].shape[1] + offset + embeds += (None,) + freqs_cis += (None,) + else: + cap_feats_len += offset + if siglip_feats is not None: + b, h, w, c = siglip_feats.shape + siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + siglip_pos_ids[:, :, 0] = cap_feats_len + 2 + siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten() + siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten() + if self.siglip_pad_token is not None: + siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check + siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) + else: + if self.siglip_pad_token is not None: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), 
dtype=torch.float32, device=device) + + if siglip_feats is None: + embeds += (None,) + freqs_cis += (None,) + else: + embeds += (siglip_feats,) + freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),) B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - H_tokens, W_tokens = H // pH, W // pW - x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) - x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - + x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options) if self.pad_tokens_multiple is not None: - pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + embeds += (x,) + freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1 + + + def patchify_and_embed( + self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={} + ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + bsz = x.shape[0] + cap_mask = None # TODO? 
+ main_siglip = None + orig_x = x + + embeds = ([], [], []) + freqs_cis = ([], [], []) + leftover_cap = [] + + start_t = 0 + omni = len(ref_latents) > 0 + if omni: + for i, ref in enumerate(ref_latents): + if i < len(ref_contexts): + ref_con = ref_contexts[i] + else: + ref_con = None + if i < len(siglip_feats): + sig_feat = siglip_feats[i] + else: + sig_feat = None + + out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options) + for i, e in enumerate(out[0]): + if e is not None: + embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz)) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + leftover_cap = ref_contexts[len(ref_latents):] + + H, W = x.shape[-2], x.shape[-1] + img_sizes = [(H, W)] * bsz + out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options) + img_len = out[0][-1].shape[1] + cap_len = out[0][0].shape[1] + for i, e in enumerate(out[0]): + if e is not None: + e = comfy.utils.repeat_to_batch_size(e, bsz) + embeds[i].append(e) + freqs_cis[i].append(out[1][i]) + start_t = out[2] + + for cap in leftover_cap: + out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype) + cap_len += out[0][0].shape[1] + embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz)) + freqs_cis[0].append(out[1][0]) + start_t += out[2] patches = transformer_options.get("patches", {}) # refine context + cap_feats = torch.cat(embeds[0], dim=1) + cap_freqs_cis = torch.cat(freqs_cis[0], dim=1) for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + + feats = (cap_feats,) + fc = (cap_freqs_cis,) + + if omni and len(embeds[1]) > 0: + siglip_mask = None + siglip_feats_combined = torch.cat(embeds[1], dim=1) + siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1) + if self.siglip_refiner is not None: + for layer in self.siglip_refiner: + siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options) + feats += (siglip_feats_combined,) + fc += (siglip_feats_freqs_cis,) padded_img_mask = None + x = torch.cat(embeds[-1], dim=1) + fc_x = torch.cat(freqs_cis[-1], dim=1) + if omni: + timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])] + else: + timestep_zero_index = None + x_input = x for i, layer in enumerate(self.noise_refiner): - x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "noise_refiner" in patches: for p in patches["noise_refiner"]: - out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) + out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"}) if "img" in out: x = out["img"] - padded_full_embed = torch.cat((cap_feats, x), dim=1) + padded_full_embed = torch.cat(feats + (x,), dim=1) + if timestep_zero_index is not None: + ind = padded_full_embed.shape[1] - x.shape[1] + timestep_zero_index = [(ind + x.shape[1] - img_len, 
ind + x.shape[1])] + timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1])) + mask = None - img_sizes = [(H, W)] * bsz - l_effective_cap_len = [cap_feats.shape[1]] * bsz - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz + return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -604,7 +807,11 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -619,8 +826,6 @@ class NextDiT(nn.Module): t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: @@ -632,7 +837,7 @@ class NextDiT(nn.Module): patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) transformer_options["total_blocks"] = len(self.layers) @@ -640,7 +845,7 @@ class NextDiT(nn.Module): img_input = img for i, layer in enumerate(self.layers): transformer_options["block_index"] = i - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) @@ -649,8 +854,7 @@ class NextDiT(nn.Module): if "txt" in out: img[:, :cap_size[0]] = out["txt"] - img = self.final_layer(img, adaln_input) + img = self.final_layer(img, adaln_input, timestep_zero_index=timestep_zero_index) img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -img diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 1ae3ef034..5a22ef030 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae(): import xformers.ops def torch_cat_if_needed(xl, dim): 
+ xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: return torch.cat(xl, dim) - else: + elif len(xl) == 1: return xl[0] + else: + return None def get_timestep_embedding(timesteps, embedding_dim): """ diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 00c597535..6eb744286 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -170,8 +170,14 @@ class Attention(nn.Module): joint_query = apply_rope1(joint_query, image_rotary_emb) joint_key = apply_rope1(joint_key, image_rotary_emb) + if encoder_hidden_states_mask is not None: + attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) + attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + else: + attn_mask = None + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, - attention_mask, transformer_options=transformer_options, + attn_mask, transformer_options=transformer_options, skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] @@ -430,6 +436,9 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask + if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask): + encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max + hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..ea123acb4 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module): # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ class WanModel(torch.nn.Module): self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ class WanModel(torch.nn.Module): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 
@@ class VaceWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000..c9dd98c4d --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum + 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + query (torch.tensor): B M H K + key (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [B, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... 
(d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...] + latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... 
(n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128]) + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None, dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * 
N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") + encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = 
intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = 
transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, model input first latent(s) need to always be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 08315f1a8..fd125ceed 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.diffusionmodules.model import vae_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed import comfy.ops ops = comfy.ops.disable_weight_init @@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) + self._padding = 2 * self.padding[0] + self.padding = (0, self.padding[1], self.padding[2]) def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): if cache_list is not None: cache_x = cache_list[cache_idx] cache_list[cache_idx] = None - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] + if cache_x is None and x.shape[2] == 1: + #Fast path - the op will pad for use by truncating the weight + #and save math on a pile of zeros. 
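# --- Reviewer sketch (not part of the patch; shapes below are made up) ---
# Why the fast path works: for a causal Conv3d, zero-padding a single input
# frame in front is equivalent to convolving with only the trailing temporal
# slice of the kernel, which is what autopad="causal_zero" exploits below.
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 1, 8, 8)                                # (B, C, T=1, H, W)
w = torch.randn(4, 4, 3, 3, 3)                                # temporal kernel size 3
full = F.conv3d(F.pad(x, (1, 1, 1, 1, 2, 0)), w)              # causal zero padding in front
fast = F.conv3d(F.pad(x, (1, 1, 1, 1, 0, 0)), w[:, :, -1:])   # truncated kernel, no temporal pad
assert torch.allclose(full, fast, atol=1e-5)
# --- end sketch ---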
+ return super().forward(x, autopad="causal_zero") + + if self._padding > 0: + padding_needed = self._padding + if cache_x is not None: + cache_x = cache_x.to(x.device) + padding_needed = max(0, padding_needed - cache_x.shape[2]) + padding_shape = list(x.shape) + padding_shape[2] = padding_needed + padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype) + x = torch_cat_if_needed([padding, cache_x, x], dim=2) del cache_x - x = F.pad(x, padding) return super().forward(x) @@ -472,10 +479,12 @@ class WanVAE(nn.Module): def encode(self, x): conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] iter_ = 1 + (t - 1) // 4 + feat_map = None + if iter_ > 1: + feat_map = [None] * count_conv3d(self.decoder) ## 对encode输入的x,按时间拆分为1、4、4、4.... for i in range(iter_): conv_idx = [0] @@ -495,10 +504,11 @@ class WanVAE(nn.Module): def decode(self, z): conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] - iter_ = z.shape[2] + feat_map = None + if iter_ > 1: + feat_map = [None] * count_conv3d(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] diff --git a/comfy/lora.py b/comfy/lora.py index 2ed0acb9d..7b31d055c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -260,6 +260,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + key_map[k[:-len(".weight")]] = to #DiffSynth lora format for k in sdk: hidden_size = model.model_config.unet_config.get("hidden_size", 0) if k.endswith(".weight") and ".linear1." 
in k: @@ -322,6 +323,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to if isinstance(model, comfy.model_base.Kandinsky5): for k in sdk: diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..858bd4cc7 --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,81 @@ +import math +import torch +from typing import NamedTuple + +from comfy.quant_ops import QuantizedTensor + +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + aligment_req = 1024 + return (size + aligment_req - 1) // aligment_req * aligment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}. 
") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views + +aimdo_allocator = None diff --git a/comfy/model_base.py b/comfy/model_base.py index c4f3c0639..85acdb66a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -48,6 +49,7 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model import comfy.model_management import comfy.patcher_extension @@ -147,6 +149,8 @@ class BaseModel(torch.nn.Module): self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.adm_channels = unet_config.get("adm_in_channels", None) if self.adm_channels is None: self.adm_channels = 0 @@ -297,7 +301,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -305,7 +309,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -320,7 +324,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -328,10 +332,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -774,8 +775,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = 
super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] @@ -946,7 +947,7 @@ class GenmoMochi(BaseModel): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -977,6 +978,60 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) @@ -1092,9 +1147,31 @@ class CosmosPredict2(BaseModel): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, 
model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device)) + if t5xxl_weights is not None: + cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + + if cross_attn.shape[1] < 512: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1])) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1114,6 +1191,35 @@ class Lumina2(BaseModel): if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni + if clip_vision_outputs is not None and len(clip_vision_outputs) > 0: + sigfeats = [] + for clip_vision_output in clip_vision_outputs: + if clip_vision_output is not None: + image_size = clip_vision_output.image_sizes[0] + shape = clip_vision_output.last_hidden_state.shape + sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1])) + if len(sigfeats) > 0: + out['siglip_feats'] = comfy.conds.CONDList(sigfeats) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_contexts = kwargs.get("reference_latents_text_embeds", None) + if ref_contexts is not None: + out['ref_contexts'] = comfy.conds.CONDList(ref_contexts) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out class WAN21(BaseModel): @@ -1471,6 +1577,9 @@ class QwenImage(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 539e296ed..8cea16e50 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["vec_in_dim"] = None + dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"]) + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') 
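# Reviewer note (worked example; values assumed from a reference Flux-style config):
# with hidden_size = 3072 and axes_dim = [16, 56, 56], sum(axes_dim) = 128 is the
# per-head dimension, so the derived head count is 3072 // 128 = 24.
hidden_size, axes_dim = 3072, [16, 56, 56]
num_heads = hidden_size // sum(axes_dim)   # == 24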
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma @@ -251,7 +253,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 - dit_config["patch_size"] = 16 + dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1) dit_config["nerf_hidden_size"] = 64 dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 @@ -305,7 +307,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 @@ -442,8 +444,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) dit_config["z_image_modulation"] = True dit_config["time_scale"] = 1000.0 + try: + dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42 + except Exception: + pass if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 + sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) + if sig_weight is not None: + dit_config["siglip_feat_dim"] = sig_weight.shape[0] return dit_config @@ -545,6 +554,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 diff --git a/comfy/model_management.py b/comfy/model_management.py index 2501cecb7..6b1166b94 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -22,11 +22,17 @@ from enum import Enum from comfy.cli_args import args, PerformanceFeature import torch import sys -import importlib import platform import weakref import gc import os +from contextlib import nullcontext +import comfy.memory_management +import comfy.utils +import comfy.quant_ops + +import comfy_aimdo.torch +import comfy_aimdo.model_vbar class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -349,15 +355,27 @@ try: except: rocm_version = (6, -1) + def aotriton_supported(gpu_arch): + path = torch.__path__[0] + path = os.path.join(os.path.join(path, "lib"), "aotriton.images") + gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path)))) + if gpu_arch in gfx: + return True + if "{}x".format(gpu_arch[:-1]) in gfx: + return True + if "{}xx".format(gpu_arch[:-2]) in gfx: + return True + return False + logging.info("AMD arch: 
{}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): - if any((a in arch) for a in ["gfx1201"]): + if any((a in arch) for a in ["gfx1200", "gfx1201"]): ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 @@ -456,7 +474,7 @@ def module_size(module): sd = module.state_dict() for k in sd: t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem += t.nbytes return module_mem class LoadedModel: @@ -567,9 +585,15 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -581,7 +605,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -596,15 +620,23 @@ def free_memory(memory_required, device, keep_loaded=[]): for x in sorted(can_unload): i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + ram_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + memory_to_free = memory_required - get_free_memory(device) + ram_to_free = ram_required - get_free_ram() + + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. 
+ memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) + if ram_to_free > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -639,7 +671,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -665,19 +700,25 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_ram_required = {} for loaded_model in models_to_load: total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we + #want to do. + #FIXME: This should subtract off the to_load current pin consumption. + total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: @@ -721,6 +762,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -738,6 +782,11 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. 
Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + def cleanup_models(): to_delete = [] @@ -781,7 +830,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev: + if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None: return torch_dev else: return cpu_dev @@ -1040,6 +1089,53 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + torch.cuda.synchronize() + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + #FIXME: This doesn't work in Aimdo because mempool cant clear cache + torch.cuda.empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() + STREAM_CAST_BUFFERS.clear() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and earlier crashes if you try and empty_cache when mempools exist + torch.cuda.empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1082,7 +1178,62 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + dest_view.copy_(tensor, non_blocking=non_blocking) + + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): + if hasattr(weight, "_v"): + #Unexpected usage patterns. There is no reason these don't work but they + #have no testing and no callers do this. 
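# --- Reviewer sketch (assumed simplification; the aligned() helper is hypothetical) ---
# What the gathered-cast path relies on: several parameters are copied into one
# flat byte buffer at aligned offsets, then re-viewed with their own dtype and
# shape, so a single transfer can service a whole module.
import torch

def aligned(nbytes, align=1024):
    return (nbytes + align - 1) // align * align

params = [torch.randn(8, 8, dtype=torch.float16), torch.randn(8, dtype=torch.float32)]
buf = torch.empty(sum(aligned(p.numel() * p.element_size()) for p in params), dtype=torch.uint8)

views, offset = [], 0
for p in params:
    nbytes = p.numel() * p.element_size()
    view = buf[offset:offset + nbytes].view(p.dtype).view(p.shape)
    view.copy_(p)                        # each view aliases a slice of the single buffer
    views.append(view)
    offset += aligned(nbytes)
# --- end sketch ---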
+ assert r is None + assert stream is None + + cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) + + if dtype is None: + dtype = weight._model_dtype + + r = torch.empty_like(weight, dtype=dtype, device=device) + + signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) + if signature is not None: + raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) + v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] + if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): + weight._v_signature = signature + #Send it over + v_tensor.copy_(weight, non_blocking=non_blocking) + #always take a deep copy even if _v is good, as we have no reasonable point to unpin + #a non comfy weight + r.copy_(v_tensor) + comfy_aimdo.model_vbar.vbar_unpin(weight._v) + return r + + if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: + #Offloaded casting could skip this, however it would make the quantizations + #inconsistent between loaded and offloaded weights. So force the double casting + #that would happen in regular flow to make offload deterministic. + cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) + cast_buffer.copy_(weight, non_blocking=non_blocking) + weight = cast_buffer + r.copy_(weight, non_blocking=non_blocking) + + return r + if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1101,10 +1252,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1124,7 +1277,7 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) def discard_cuda_async_error(): try: @@ -1156,7 +1309,7 @@ def pin_memory(tensor): if not tensor.is_contiguous(): return False - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1183,7 +1336,7 @@ def unpin_memory(tensor): return False ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes size_stored = PINNED_MEMORY.get(ptr, None) if size_stored is None: @@ -1504,6 +1657,16 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): @@ -1536,8 +1699,11 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + if comfy.memory_management.aimdo_allocator is None: + #Pytorch 2.7 and 
earlier crashes if you try and empty_cache when mempools exist + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..b70c031bf 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -38,19 +38,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP - -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -123,6 +111,10 @@ def move_weight_functions(m, device): memory += f.move_to(device=device) return memory +def string_to_seed(data): + logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils") + return comfy.utils.string_to_seed(data) + class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key @@ -169,6 +161,11 @@ def get_key_weight(model, key): return weight, set_func, convert_func +def key_param_name_to_key(key, param): + if len(key) == 0: + return param + return "{}.{}".format(key, param) + class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model @@ -212,6 +209,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. 
The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -269,6 +287,9 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size @@ -284,6 +305,9 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter + def get_free_memory(self, device): + return comfy.model_management.get_free_memory(device) + def clone(self): n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} @@ -611,14 +635,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -631,13 +655,15 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -654,7 +680,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, prio_comfy_cast_weights=False): loading = [] for n, m in self.model.named_modules(): params = [] @@ -681,7 +707,8 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () + loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -718,6 +745,7 @@ class ModelPatcher: continue cast_weight = 
self.force_cast_weights + m.comfy_force_cast_weights = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] @@ -772,7 +800,7 @@ class ModelPatcher: continue for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) if comfy.model_management.is_device_cuda(device_to): @@ -788,13 +816,14 @@ class ModelPatcher: n = x[1] params = x[3] for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) + usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) @@ -893,7 +922,7 @@ class ModelPatcher: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: @@ -944,7 +973,7 @@ class ModelPatcher: logging.debug("freed {}".format(n)) for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) self.model.model_lowvram = True @@ -982,6 +1011,9 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) @@ -1315,10 +1347,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device @@ -1353,7 +1385,249 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in 
unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + #this is now way more dynamic and we dont support the same base model for both Dynamic + #and non-dynamic patchers. + if hasattr(self.model, "model_loaded_weight_memory"): + del self.model.model_loaded_weight_memory + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + # x10. We dont know what model defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + if vbar is None: + return 0 + return vbar.loaded_size() + + def get_free_memory(self, device): + #NOTE: on high condition / batch counts, estimate should have already vacated + #all non-dynamic models so this is safe even if its not 100% true that this + #would all be avaiable for inference use. + return comfy.model_management.get_total_memory(device) - self.model_size() + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dymamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dymamic weight loading") + + def unpin_all_weights(self): + self.partially_unload_ram(1e32) + + def memory_required(self, input_shape): + #Pad this significantly. We are trying to get away from precise estimates. This + #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you + #use all ModelPatcherDynamic this is ignored and its all done dynamically. + return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + + + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): + + #Force patching doesn't make sense in Dynamic loading, as you dont know what does and + #doesn't need to be forced at this stage. 
The only thing you could do would be patch + #it all on CPU which consumes huge RAM. + assert not force_patch_weights + + #Full load doesn't make sense as we dont actually have any loader capability here and + #now. + assert not full_load + + assert device_to == self.load_device + + num_patches = 0 + allocated_size = 0 + + with self.use_ejected(): + self.unpatch_hooks() + + vbar = self._vbar_get(create=True) + if vbar is not None: + vbar.prioritize() + + #We have way more tools for acceleration on comfy weight offloading, so always + #prioritize the non-comfy weights (note the order reverse). + loading = self._load_list(prio_comfy_cast_weights=True) + loading.sort(reverse=True) + + for x in loading: + _, _, _, n, m, params = x + + def set_dirty(item, dirty): + if dirty or not hasattr(item, "_v_signature"): + item._v_signature = None + + def setup_param(self, m, n, param_key): + nonlocal num_patches + key = key_param_name_to_key(n, param_key) + + weight_function = [] + + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return 0 + if key in self.patches: + setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + num_patches += 1 + else: + setattr(m, param_key + "_lowvram_function", None) + + if key in self.weight_wrapper_patches: + weight_function.extend(self.weight_wrapper_patches[key]) + setattr(m, param_key + "_function", weight_function) + geometry = weight + if not isinstance(weight, QuantizedTensor): + model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype) + weight._model_dtype = model_dtype + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + return comfy.memory_management.vram_aligned_size(geometry) + + if hasattr(m, "comfy_cast_weights"): + m.comfy_cast_weights = True + m.pin_failed = False + m.seed_key = n + set_dirty(m, dirty) + + v_weight_size = 0 + v_weight_size += setup_param(self, m, n, "weight") + v_weight_size += setup_param(self, m, n, "bias") + + if vbar is not None and not hasattr(m, "_v"): + m._v = vbar.alloc(v_weight_size) + allocated_size += v_weight_size + + else: + for param in params: + key = key_param_name_to_key(n, param) + weight, _, _ = get_key_weight(self.model, key) + weight.seed_key = key + set_dirty(weight, dirty) + geometry = weight + model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype) + geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) + weight_size = geometry.numel() * geometry.element_size() + if vbar is not None and not hasattr(weight, "_v"): + weight._v = vbar.alloc(weight_size) + weight._model_dtype = model_dtype + allocated_size += weight_size + + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + + self.model.device = device_to + self.model.current_weight_patches_uuid = self.patches_uuid + + for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): + #These are all super dangerous. Who knows what the custom nodes actually do here... 
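# --- Reviewer sketch (toy stand-in, not the comfy_aimdo API) ---
# The idea behind "prepared for dynamic VRAM loading": load() only records how
# many bytes each module would need; the device copy is materialized lazily on
# first use and evicted again when a budget is exceeded.
class LazyWeightCache:
    def __init__(self, budget_bytes):
        self.budget = budget_bytes
        self.resident = {}               # key -> (device_tensor, nbytes)
        self.used = 0

    def fetch(self, key, cpu_tensor, device):
        if key not in self.resident:
            nbytes = cpu_tensor.numel() * cpu_tensor.element_size()
            while self.used + nbytes > self.budget and self.resident:
                oldest = next(iter(self.resident))           # evict in insertion order
                self.used -= self.resident.pop(oldest)[1]
            self.resident[key] = (cpu_tensor.to(device, non_blocking=True), nbytes)
            self.used += nbytes
        return self.resident[key][0]
# --- end sketch ---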
+ callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load) + + self.apply_hooks(self.forced_hooks, force_apply=True) + + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): + assert not force_patch_weights #See above + assert self.load_device != torch.device("cpu") + + vbar = self._vbar_get() + return 0 if vbar is None else vbar.free_memory(memory_to_free) + + def partially_unload_ram(self, ram_to_unload): + loading = self._load_list(prio_comfy_cast_weights=True) + for x in loading: + _, _, _, _, m, _ = x + ram_to_unload -= comfy.pinned_memory.unpin_memory(m) + if ram_to_unload <= 0: + return + + def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): + #This isn't used by the core at all and can only be to load a model out of + #the control of proper model_managment. If you are a custom node author reading + #this, the correct pattern is to call load_models_gpu() to get a proper + #managed load of your model. + assert not load_weights + return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights) + + def unpatch_model(self, device_to=None, unpatch_weights=True): + super().unpatch_model(device_to=None, unpatch_weights=False) + + if unpatch_weights: + self.partially_unload_ram(1e32) + self.partially_unload(None) + + def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): + assert not force_patch_weights #See above + with self.use_ejected(skip_and_inject_on_exit_only=True): + dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid) + + self.unpatch_model(self.offload_device, unpatch_weights=False) + self.patch_model(load_weights=False) + + try: + self.load(device_to, dirty=dirty) + except Exception as e: + self.detach() + raise e + #ModelPatcher::partially_load returns a number on what got loaded but + #nothing in core uses this and we have no data in the Dynamic world. Hit + #the custom node devs with a None rather than a 0 that would mislead any + #logic they might have. + return None + + def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter): + assert False #Should be unreachable - we dont ever cache in the new implementation + + def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter): + if key not in combined_patches: + return + + raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. 
Please remove --fast arguments form ComfyUI startup") + + def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None: + pass + +CoreModelPatcher = ModelPatcher diff --git a/comfy/ops.py b/comfy/ops.py index 16889bb82..53c5e4dc3 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -19,10 +19,16 @@ import torch import logging import comfy.model_management -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import comfy.float import comfy.rmsnorm import json +import comfy.memory_management +import comfy.pinned_memory +import comfy.utils + +import comfy_aimdo.model_vbar +import comfy_aimdo.torch def run_every_op(): if torch.compiler.is_compiling(): @@ -72,14 +78,122 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): + offload_stream = None + xfer_dest = None + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + if signature is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + + if not resident: + cast_dest = None + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_source = [ pin ] + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = params[0] + bias = params[1] + + def post_cast(s, param_key, x, dtype, resident, update_weight): + lowvram_fn = getattr(s, param_key + 
"_lowvram_function", None) + fns = getattr(s, param_key + "_function", []) + + orig = x + + def to_dequant(tensor, dtype): + tensor = tensor.to(dtype=dtype) + if isinstance(tensor, QuantizedTensor): + tensor = tensor.dequantize() + return tensor + + if orig.dtype != dtype or len(fns) > 0: + x = to_dequant(x, dtype) + if not resident and lowvram_fn is not None: + x = to_dequant(x, dtype if compute_dtype is None else compute_dtype) + #FIXME: this is not accurate, we need to be sensitive to the compute dtype + x = lowvram_fn(x) + if (isinstance(orig, QuantizedTensor) and + (orig.dtype == dtype and len(fns) == 0 or update_weight)): + seed = comfy.utils.string_to_seed(s.seed_key) + y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) + if orig.dtype == dtype and len(fns) == 0: + #The layer actually wants our freshly saved QT + x = y + else: + y = x + if update_weight: + orig.copy_(y) + for f in fns: + x = f(x) + return x + + update_weight = signature is not None + + weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + if s.bias is not None: + bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + s._v_signature=signature + + #FIXME: weird offload return protocol + return weight, bias, (offload_stream, device if signature is not None else None, None) + + +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: if isinstance(input, QuantizedTensor): - dtype = input._layout_params["orig_dtype"] + dtype = input.params.orig_dtype else: dtype = input.dtype if bias_dtype is None: @@ -87,22 +201,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + non_blocking = comfy.model_management.device_supports_non_blocking(device) + + if hasattr(s, "_v"): + return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) + if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): offload_stream = comfy.model_management.get_offload_stream(device) else: offload_stream = None - non_blocking = comfy.model_management.device_supports_non_blocking(device) + bias = None + weight = None + + if offload_stream is not None and not args.cuda_malloc: + cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ]) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + #The streams can be uneven in buffer capability and reject us. 
Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +240,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) @@ -131,14 +262,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is not good RTTI + if not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -149,6 +286,57 @@ class CastWeightBiasOp: class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + super().__init__(in_features, out_features, bias, device, dtype) + return + + # Issue is with `torch.empty` still reserving the full memory for the layer. + # Windows doesn't over-commit memory so without this, We are momentarily commit + # charged for the weight even though we might zero-copy it when we load the + # state dict. If the commit charge exceeds the ceiling we can destabilize the + # system. 
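A minimal sketch of the allocation behaviour the comment above describes, outside of ComfyUI; the class names are illustrative only, and the point is just that deferring parameter construction until the state dict is available avoids the up-front torch.empty commit charge:

import torch

class EagerLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Commits memory for the full weight immediately, even if a loaded
        # state-dict tensor replaces it moments later.
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)

class DeferredLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = None  # nothing allocated until load time

    def assign_weight(self, loaded):
        # Wrap the tensor that was already materialized while reading the
        # checkpoint instead of allocating and copying into a second buffer.
        self.weight = torch.nn.Parameter(loaded, requires_grad=False)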
+ torch.nn.Module.__init__(self) + self.in_features = in_features + self.out_features = out_features + self.weight = None + self.bias = None + self.comfy_need_lazy_init_bias=bias + self.weight_comfy_model_dtype = dtype + self.bias_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not enables_dynamic_vram(): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k,v in state_dict.items(): + if k[prefix_len:] == "weight": + if not assign_to_params_buffers: + v = v.clone() + self.weight = torch.nn.Parameter(v, requires_grad=False) + elif k[prefix_len:] == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + self.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + #Reconcile default construction of the weight if its missing. + if self.weight is None: + v = torch.zeros(self.in_features, self.out_features) + self.weight = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"weight") + if self.bias is None and self.comfy_need_lazy_init_bias: + v = torch.zeros(self.out_features,) + self.bias = torch.nn.Parameter(v, requires_grad=False) + missing_keys.append(prefix+"bias") + + def reset_parameters(self): return None @@ -203,7 +391,9 @@ class disable_weight_init: def reset_parameters(self): return None - def _conv_forward(self, input, weight, bias, *args, **kwargs): + def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs): + if autopad == "causal_zero": + weight = weight[:, :, -input.shape[2]:, :, :] if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16): out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True) if bias is not None: @@ -212,15 +402,15 @@ class disable_weight_init: else: return super()._conv_forward(input, weight, bias, *args, **kwargs) - def forward_comfy_cast_weights(self, input): + def forward_comfy_cast_weights(self, input, autopad=None): weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - x = self._conv_forward(input, weight, bias) + x = self._conv_forward(input, weight, bias, autopad=autopad) uncast_bias_weight(self, weight, bias, offload_stream) return x def forward(self, *args, **kwargs): run_every_op() - if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs: return self.forward_comfy_cast_weights(*args, **kwargs) else: return super().forward(*args, **kwargs) @@ -412,26 +602,34 @@ def fp8_linear(self, input): return None input_dtype = input.dtype + input_shape = input.shape + tensor_3d = input.ndim == 3 - if input.ndim == 3 or input.ndim == 2: - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - 
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) + if input.ndim != 2: + return None + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - # Wrap weight in QuantizedTensor - this enables unified dispatch - # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! - layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) - o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.to(dtype).contiguous() + layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) + quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input) - uncast_bias_weight(self, w, bias, offload_stream) - return o + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - return None + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_3d: + o = o.reshape((input_shape[0], input_shape[1], w.shape[0])) + + return o class fp8_ops(manual_cast): class Linear(manual_cast.Linear): @@ -477,14 +675,20 @@ if CUBLAS_IS_AVAILABLE: # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, QUANT_ALGOS +from .quant_ops import ( + QuantizedTensor, + QUANT_ALGOS, + TensorCoreFP8Layout, + get_layout_class, +) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -497,21 +701,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - if dtype is None: - dtype = MixedPrecisionOps._compute_dtype - - self.factory_kwargs = {"device": device, "dtype": dtype} + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features - self._has_bias = bias + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None + def 
_load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -520,7 +736,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_key = f"{prefix}weight" weight = state_dict.pop(weight_key, None) if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + logging.warning(f"Missing weight for layer {layer_name}") + return manually_loaded_keys = [weight_key] @@ -529,49 +746,61 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - dtype = self.factory_kwargs["dtype"] - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) - if dtype != MixedPrecisionOps._compute_dtype: - self.comfy_cast_weights = True - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) - else: - self.register_parameter("bias", None) + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) - weight_scale_key = f"{prefix}weight_scale" - scale = state_dict.pop(weight_scale_key, None) - if scale is not None: - scale = scale.to(device) - layout_params = { - 'scale': scale, - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } + # Load format-specific parameters + if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: + # FP8: single tensor scale + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if scale is not None: - manually_loaded_keys.append(weight_scale_key) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "nvfp4": + # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.float8_e4m3fn) + + if tensor_scale is None or block_scale is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + else: + raise 
ValueError(f"Unsupported quantization format: {self.quant_format}") self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), requires_grad=False ) - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) - else: - self.register_parameter("bias", None) - for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue # Already handled above + param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -586,20 +815,36 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec missing_keys.remove(key) def state_dict(self, *args, destination=None, prefix="", **kwargs): - sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if destination is not None: + sd = destination + else: + sd = {} + + if self.bias is not None: + sd["{}bias".format(prefix)] = self.bias + if isinstance(self.weight, QuantizedTensor): - sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + + input_scale = getattr(self, 'input_scale', None) + if input_scale is not None: + sd["{}input_scale".format(prefix)] = input_scale + else: + sd["{}weight".format(prefix)] = self.weight return sd def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -607,12 +852,36 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def forward(self, input, *args, **kwargs): run_every_op() - if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) + input_shape = input.shape + reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype + if (getattr(self, 'layout_type', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + not isinstance(input, QuantizedTensor) and not self._full_precision_mm and + not getattr(self, 'comfy_force_cast_weights', False) and + len(self.weight_function) == 0 and len(self.bias_function) == 0): + + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input + + # Fall back to non-quantized for 
non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) + + + output = self.forward_comfy_cast_weights(input, compute_dtype) + + # Reshape output back to 3D if input was 3D + if reshaped_3d: + output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + + return output def convert_weight(self, weight, inplace=False, **kwargs): if isinstance(weight, QuantizedTensor): @@ -622,7 +891,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + # dtype is now implicit in the layout class + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -649,10 +919,17 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not fp8_compute: + disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..8acc327a7 --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,29 @@ +import torch +import comfy.model_management +import comfy.memory_management + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + pin = torch.empty((size,), dtype=torch.uint8) + if comfy.model_management.pin_memory(pin): + module._pin = pin + else: + module.pin_failed = True + return False + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + comfy.model_management.unpin_memory(module._pin) + del module._pin + return size diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..15a4f457b 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,580 +1,174 @@ import torch import logging -from typing import Tuple, Dict + +try: + import comfy_kitchen as ck + from 
comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + TensorCoreFP8Layout as _CKFp8Layout, + TensorCoreNVFP4Layout as _CKNvfp4Layout, + register_layout_op, + register_layout_class, + get_layout_class, + ) + _CK_AVAILABLE = True + if torch.version.cuda is None: + ck.registry.disable("cuda") + else: + cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) + if cuda_version < (13,): + ck.registry.disable("cuda") + logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") + + ck.registry.disable("triton") + for k, v in ck.list_backends().items(): + logging.info(f"Found comfy_kitchen backend {k}: {v}") +except ImportError as e: + logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.") + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class _CKFp8Layout: + pass + + class _CKNvfp4Layout: + pass + + def register_layout_class(name, cls): + pass + + def get_layout_class(name): + return None + import comfy.float -_LAYOUT_REGISTRY = {} -_GENERIC_UTILS = {} - - -def register_layout_op(torch_op, layout_type): - """ - Decorator to register a layout-specific operation handler. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) - layout_type: Layout class (e.g., TensorCoreFP8Layout) - Example: - @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) - def fp8_linear(func, args, kwargs): - # FP8-specific linear implementation - ... - """ - def decorator(handler_func): - if torch_op not in _LAYOUT_REGISTRY: - _LAYOUT_REGISTRY[torch_op] = {} - _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func - return handler_func - return decorator - - -def register_generic_util(torch_op): - """ - Decorator to register a generic utility that works for all layouts. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - - Example: - @register_generic_util(torch.ops.aten.detach.default) - def generic_detach(func, args, kwargs): - # Works for any layout - ... - """ - def decorator(handler_func): - _GENERIC_UTILS[torch_op] = handler_func - return handler_func - return decorator - - -def _get_layout_from_args(args): - for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type - return None - - -def _move_layout_params_to_device(params, device): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.to(device=device) - else: - new_params[k] = v - return new_params - - -def _copy_layout_params(params): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.clone() - else: - new_params[k] = v - return new_params - -def _copy_layout_params_inplace(src, dst, non_blocking=False): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - dst[k].copy_(v, non_blocking=non_blocking) - else: - dst[k] = v - -class QuantizedLayout: - """ - Base class for quantization layouts. - - A layout encapsulates the format-specific logic for quantization/dequantization - and provides a uniform interface for extracting raw tensors needed for computation. - - New quantization formats should subclass this and implement the required methods. 
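The contract this (removed) base class documents is still what the comfy_kitchen-backed layouts later in this file implement: a quantize step that returns raw data plus scale bookkeeping, and a matching dequantize. A self-contained sketch of that contract in plain PyTorch, using a per-tensor int8 scheme purely as an example (not a format ComfyUI ships):

import torch

class Int8PerTensorLayoutExample:  # illustrative name, not a ComfyUI class
    @classmethod
    def quantize(cls, tensor):
        # One scale for the whole tensor; real layouts also record orig_dtype/orig_shape.
        scale = tensor.abs().amax().clamp(min=1e-8).float() / 127.0
        qdata = torch.clamp(torch.round(tensor / scale), -127, 127).to(torch.int8)
        return qdata, {"scale": scale, "orig_dtype": tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype):
        return qdata.to(orig_dtype) * scale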
- """ - @classmethod - def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: - raise NotImplementedError(f"{cls.__name__} must implement quantize()") - - @staticmethod - def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError("TensorLayout must implement dequantize()") - - @classmethod - def get_plain_tensors(cls, qtensor) -> torch.Tensor: - raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") - - -class QuantizedTensor(torch.Tensor): - """ - Universal quantized tensor that works with any layout. - - This tensor subclass uses a pluggable layout system to support multiple - quantization formats (FP8, INT4, INT8, etc.) without code duplication. - - The layout_type determines format-specific behavior, while common operations - (detach, clone, to) are handled generically. - - Attributes: - _qdata: The quantized tensor data - _layout_type: Layout class (e.g., TensorCoreFP8Layout) - _layout_params: Dict with layout-specific params (scale, zero_point, etc.) - """ - - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. - - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) - - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params - - def __repr__(self): - layout_name = self._layout_type - param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) - return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - - @property - def layout_type(self): - return self._layout_type - - def __tensor_flatten__(self): - """ - Tensor flattening protocol for proper device movement. - """ - inner_tensors = ["_qdata"] - ctx = { - "layout_type": self._layout_type, - } - - tensor_params = {} - non_tensor_params = {} - for k, v in self._layout_params.items(): - if isinstance(v, torch.Tensor): - tensor_params[k] = v - else: - non_tensor_params[k] = v - - ctx["tensor_param_keys"] = list(tensor_params.keys()) - ctx["non_tensor_params"] = non_tensor_params - - for k, v in tensor_params.items(): - attr_name = f"_layout_param_{k}" - object.__setattr__(self, attr_name, v) - inner_tensors.append(attr_name) - - return inner_tensors, ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): - """ - Tensor unflattening protocol for proper device movement. - Reconstructs the QuantizedTensor after device movement. 
- """ - layout_type = ctx["layout_type"] - layout_params = dict(ctx["non_tensor_params"]) - - for key in ctx["tensor_param_keys"]: - attr_name = f"_layout_param_{key}" - layout_params[key] = inner_tensors[attr_name] - - return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) - - @classmethod - def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) - return cls(qdata, layout_type, layout_params) - - def dequantize(self) -> torch.Tensor: - return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Step 1: Check generic utilities first (detach, clone, to, etc.) - if func in _GENERIC_UTILS: - return _GENERIC_UTILS[func](func, args, kwargs) - - # Step 2: Check layout-specific handlers (linear, matmul, etc.) - layout_type = _get_layout_from_args(args) - if layout_type and func in _LAYOUT_REGISTRY: - handler = _LAYOUT_REGISTRY[func].get(layout_type) - if handler: - return handler(func, args, kwargs) - - # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) - - @classmethod - def _dequant_and_fallback(cls, func, args, kwargs): - def dequant_arg(arg): - if isinstance(arg, QuantizedTensor): - return arg.dequantize() - elif isinstance(arg, (list, tuple)): - return type(arg)(dequant_arg(a) for a in arg) - return arg - - new_args = dequant_arg(args) - new_kwargs = dequant_arg(kwargs) - return func(*new_args, **new_kwargs) - - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self, *arg, **kwargs): - return self._qdata.is_contiguous(*arg, **kwargs) - - def storage(self): - return self._qdata.storage() - # ============================================================================== -# Generic Utilities (Layout-Agnostic Operations) +# FP8 Layouts with Comfy-Specific Extensions # ============================================================================== -def _create_transformed_qtensor(qt, transform_fn): - new_data = transform_fn(qt._qdata) - new_params = _copy_layout_params(qt._layout_params) - return QuantizedTensor(new_data, qt._layout_type, new_params) +class _TensorCoreFP8LayoutBase(_CKFp8Layout): + FP8_DTYPE = None # Must be overridden in subclass - -def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_layout is not None and target_layout != torch.strided: - logging.warning( - f"QuantizedTensor: layout change requested to {target_layout}, " - f"but not supported. Ignoring layout." 
- ) - - # Handle device transfer - current_device = qt._qdata.device - if target_device is not None: - # Normalize device for comparison - if isinstance(target_device, str): - target_device = torch.device(target_device) - if isinstance(current_device, str): - current_device = torch.device(current_device) - - if target_device != current_device: - logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") - new_q_data = qt._qdata.to(device=target_device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - if target_dtype is not None: - new_params["orig_dtype"] = target_dtype - new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) - logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") - return new_qt - - logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") - return qt - - -@register_generic_util(torch.ops.aten.detach.default) -def generic_detach(func, args, kwargs): - """Detach operation - creates a detached copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.detach()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.clone.default) -def generic_clone(func, args, kwargs): - """Clone operation - creates a deep copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.clone()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._to_copy.default) -def generic_to_copy(func, args, kwargs): - """Device/dtype transfer operation - handles .to(device) calls.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - op_name="_to_copy" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype_layout) -def generic_to_dtype_layout(func, args, kwargs): - """Handle .to(device) calls using the dtype_layout variant.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - target_layout=kwargs.get('layout', None), - op_name="to" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.copy_.default) -def generic_copy_(func, args, kwargs): - qt_dest = args[0] - src = args[1] - non_blocking = args[2] if len(args) > 2 else False - if isinstance(qt_dest, QuantizedTensor): - if isinstance(src, QuantizedTensor): - # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) - qt_dest._layout_type = src._layout_type - orig_dtype = qt_dest._layout_params["orig_dtype"] - _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) - qt_dest._layout_params["orig_dtype"] = orig_dtype - else: - # Copy from regular tensor - just copy raw data - qt_dest._qdata.copy_(src) - return qt_dest - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = 
target_dtype - return src - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) -def generic_has_compatible_shallow_copy_type(func, args, kwargs): - return True - - -@register_generic_util(torch.ops.aten.empty_like.default) -def generic_empty_like(func, args, kwargs): - """Empty_like operation - creates an empty tensor with the same quantized structure.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - # Create empty tensor with same shape and dtype as the quantized data - hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) - new_qdata = torch.empty_like(qt._qdata, **kwargs) - - # Handle device transfer for layout params - target_device = kwargs.get('device', new_qdata.device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - - # Update orig_dtype if dtype is specified - new_params['orig_dtype'] = hp_dtype - - return QuantizedTensor(new_qdata, qt._layout_type, new_params) - return func(*args, **kwargs) - -# ============================================================================== -# FP8 Layout + Operation Handlers -# ============================================================================== -class TensorCoreFP8Layout(QuantizedLayout): - """ - Storage format: - - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) - - scale: Scalar tensor (float32) for dequantization - - orig_dtype: Original dtype before quantization (for casting back) - """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if cls.FP8_DTYPE is None: + raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE") + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) if isinstance(scale, str) and scale == "recalculate": - scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small tensor_info = torch.finfo(tensor.dtype) scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if scale is not None: - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is None: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) else: tensor = tensor * (1.0 / scale).to(tensor.dtype) + qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding) else: - scale = torch.ones((), device=tensor.device, dtype=torch.float32) + qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE) + + params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape) + return qdata, params + + +class TensorCoreNVFP4Layout(_CKNvfp4Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale 
= torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape if stochastic_rounding > 0: - tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) else: - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) - tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) - layout_params = { - 'scale': scale, - 'orig_dtype': orig_dtype - } - return tensor, layout_params + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params - @staticmethod - def dequantize(qdata, scale, orig_dtype, **kwargs): - plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - plain_tensor.mul_(scale) - return plain_tensor - @classmethod - def get_plain_tensors(cls, qtensor): - return qtensor._qdata, qtensor._layout_params['scale'] +class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e4m3fn + + +class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e5m2 + + +# Backward compatibility alias - default to E4M3 +TensorCoreFP8Layout = TensorCoreFP8E4M3Layout + + +# ============================================================================== +# Registry +# ============================================================================== + +register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) +register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) +register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) +register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, - "comfy_tensor_layout": "TensorCoreFP8Layout", + "comfy_tensor_layout": "TensorCoreFP8E4M3Layout", + }, + "float8_e5m2": { + "storage_t": torch.float8_e5m2, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8E5M2Layout", + }, + "nvfp4": { + "storage_t": torch.uint8, + "parameters": {"weight_scale", "weight_scale_2", "input_scale"}, + "comfy_tensor_layout": "TensorCoreNVFP4Layout", + "group_size": 16, }, } -LAYOUTS = { - "TensorCoreFP8Layout": TensorCoreFP8Layout, -} +# ============================================================================== +# Re-exports for backward compatibility +# ============================================================================== -@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") -def fp8_linear(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - weight_t = plain_weight.t() - - 
tensor_2d = False - if len(plain_input.shape) == 2: - tensor_2d = True - plain_input = plain_input.unsqueeze(1) - - input_shape = plain_input.shape - if len(input_shape) != 3: - return None - - try: - output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]).contiguous(), - weight_t, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - - if not tensor_2d: - output = output.reshape((-1, input_shape[1], weight.shape[0])) - - if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = scale_a * scale_b - output_params = { - 'scale': output_scale, - 'orig_dtype': input_tensor._layout_params['orig_dtype'] - } - return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) - else: - return output - - except Exception as e: - raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - - # Case 2: DQ Fallback - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - if isinstance(input_tensor, QuantizedTensor): - input_tensor = input_tensor.dequantize() - - return torch.nn.functional.linear(input_tensor, weight, bias) - -def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output - -@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") -def fp8_addmm(func, args, kwargs): - input_tensor = args[1] - weight = args[2] - bias = args[0] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - if isinstance(args[2], QuantizedTensor): - a[2] = args[2].dequantize() - - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") -def fp8_mm(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") -@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") -def fp8_func(func, args, kwargs): - input_tensor = args[0] - if isinstance(input_tensor, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - ar = list(args) - ar[0] = plain_input - return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) - return func(*args, **kwargs) +__all__ = [ + "QuantizedTensor", + "QuantizedLayout", + "TensorCoreFP8Layout", + 
"TensorCoreFP8E4M3Layout", + "TensorCoreFP8E5M2Layout", + "TensorCoreNVFP4Layout", + "QUANT_ALGOS", + "register_layout_op", +] diff --git a/comfy/sample.py b/comfy/sample.py index 2f8f3a51c..a2a39b527 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises -def fix_empty_latent_channels(model, latent_image): +def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels - if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if torch.count_nonzero(latent_image) == 0: + if latent_format.latent_channels != latent_image.shape[1]: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if downscale_ratio_spacial is not None: + if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: + ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio + latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled") + if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: latent_image = latent_image.unsqueeze(2) return latent_image diff --git a/comfy/samplers.py b/comfy/samplers.py index 1989ef107..8b9782956 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers @@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] diff --git a/comfy/sd.py b/comfy/sd.py index 7de7dd9c6..fd0ac85e7 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -57,6 +58,7 @@ import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 import comfy.text_encoders.newbie +import comfy.text_encoders.anima import comfy.model_patcher import comfy.lora @@ -100,6 +102,105 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. 
+ """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: @@ -127,8 +228,10 @@ class CLIP: self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + self.patcher = 
comfy.model_patcher.CoreModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcode upcast in TE implemention self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -218,7 +321,7 @@ class CLIP: if unprojected: self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) @@ -266,7 +369,7 @@ class CLIP: if return_pooled == "unprojected": self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] @@ -288,8 +391,18 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and its + # a bit of an API change to plumb this all the way through. + # So spray paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): @@ -299,8 +412,11 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip - def load_model(self): - model_management.load_model_gpu(self.patcher) + def load_model(self, tokens={}): + memory_used = 0 + if hasattr(self.cond_stage_model, "memory_estimation_function"): + memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) return self.patcher def get_key_patches(self): @@ -476,8 +592,8 @@ class VAE: self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 - self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) @@ -632,14 +748,13 @@ class VAE: self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in 
()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_input = self.process_output = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical @@ -662,12 +777,7 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) + model_management.archive_model_dtypes(self.first_stage_model) if device is None: device = model_management.vae_device() @@ -679,7 +789,18 @@ class VAE: self.first_stage_model.to(self.vae_dtype) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -794,7 +915,7 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) @@ -868,7 +989,7 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None @@ -1011,6 +1132,7 @@ class CLIPType(Enum): KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 NEWBIE = 24 + FLUX2 = 25 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1041,7 +1163,10 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 
16 QWEN3_2B = 17 - JINA_CLIP_2 = 18 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 + QWEN3_8B = 20 + QWEN3_06B = 21 def detect_te_model(sd): @@ -1055,9 +1180,9 @@ def detect_te_model(sd): return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + if weight.shape[0] == 10240: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight.shape[0] == 5120: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD @@ -1067,6 +1192,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B @@ -1083,6 +1210,10 @@ def detect_te_model(sd): return TEModel.QWEN3_4B elif weight.shape[0] == 2048: return TEModel.QWEN3_2B + elif weight.shape[0] == 4096: + return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1208,14 +1339,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) elif te_model == TEModel.QWEN3_4B: - clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer + else: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.QWEN3_8B: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -1271,6 +1412,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: clip_target.clip = 
comfy.text_encoders.newbie.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer @@ -1305,7 +1450,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1393,7 +1538,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1436,7 +1582,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1528,13 +1673,14 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) - + return model_patcher def load_diffusion_model(unet_path, model_options={}): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) @@ -1565,9 +1711,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c512ca5d0..9ecfc9c55 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -297,7 +297,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return 
self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) def parse_parentheses(string): result = [] @@ -466,7 +466,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No return embed_out class SDTokenizer: - def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): + def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}): if tokenizer_path is None: tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args) @@ -479,8 +479,15 @@ class SDTokenizer: empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token if has_start_token: - self.tokens_start = 1 - self.start_token = empty[0] + if len(empty) > 0: + self.tokens_start = 1 + self.start_token = empty[0] + else: + self.tokens_start = 0 + self.start_token = start_token + if start_token is None: + logging.warning("WARNING: There's something wrong with your tokenizers.'") + if end_token is not None: self.end_token = end_token else: @@ -488,7 +495,7 @@ class SDTokenizer: self.end_token = empty[1] else: self.tokens_start = 0 - self.start_token = None + self.start_token = start_token if end_token is not None: self.end_token = end_token else: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..d25271d6e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima from . import supported_models_base from . 
import latent_formats @@ -763,17 +764,31 @@ class Flux2(Flux): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604) def get_model(self, state_dict, prefix="", device=None): out = model_base.Flux2(self, device=device) return out def clip_target(self, state_dict={}): - return None # TODO pref = self.text_encoder_key_prefix[0] - t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_4b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_8b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref)) + if len(detect) > 0: + if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict: + detect["pruned"] = True + return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect)) + + return None class GenmoMochi(supported_models_base.BASE): unet_config = { @@ -836,6 +851,21 @@ class LTXV(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.077 # TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -977,6 +1007,36 @@ class CosmosT2IPredict2(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect 
= comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1027,13 +1087,13 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 2.0 + memory_usage_factor = 2.8 supported_inference_dtypes = [torch.bfloat16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) - if comfy.model_management.extended_fp16_support(): + if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False): self.supported_inference_dtypes = self.supported_inference_dtypes.copy() self.supported_inference_dtypes.insert(1, torch.float16) @@ -1536,6 +1596,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 0e5f9a378..6c06ce19d 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ 
-124,6 +125,9 @@ class TAEHV(nn.Module): self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,52 @@ class TAEHV(nn.Module): self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = 
value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: - x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) + x = F.pixel_unshuffle(x, self.patch_size) + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) if self.patch_size > 1: diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py new file mode 100644 index 000000000..41f95bcb6 --- /dev/null +++ b/comfy/text_encoders/anima.py @@ -0,0 +1,61 @@ +from transformers import Qwen2Tokenizer, T5TokenizerFast +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch + + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='qwen3_06b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(token, 1.0) for token, _ in inner_list] for inner_list in 
qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 448381fa9..f4b40ac68 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return CosmosTEModel_ diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 21d93d757..f67a5f805 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -3,7 +3,7 @@ import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast, LlamaTokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer import torch import os import json @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, 
has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} @@ -172,3 +172,60 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): model_options["num_layers"] = 30 super().__init__(device=device, dtype=dtype, model_options=model_options) return Flux2TEModel_ + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class KleinTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): + if name == "qwen3_4b": + tokenizer = Qwen3Tokenizer + elif name == "qwen3_8b": + tokenizer = Qwen3Tokenizer8B + + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class KleinTokenizer8B(KleinTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name) + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_8BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +def 
klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"): + if model_type == "qwen3_4b": + model = Qwen3_4BModel + elif model_type == "qwen3_8b": + model = Qwen3_8BModel + + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) + return Flux2TEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 5daea8135..2d7a3fbce 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return MochiTEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index a9a6c525e..2ddb4da60 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -10,9 +10,11 @@ import comfy.utils def llama_detect(state_dict, prefix=""): out = {} - t5_key = "{}model.norm.weight".format(prefix) - if t5_key in state_dict: - out["dtype_llama"] = state_dict[t5_key].dtype + norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)] + for norm_key in norm_keys: + if norm_key in state_dict: + out["dtype_llama"] = state_dict[norm_key].dtype + break quant = comfy.utils.detect_layer_quantization(state_dict, prefix) if quant is not None: diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index faa4e1de8..68ac1e804 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,12 +1,13 @@ import torch import torch.nn as nn from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, Tuple import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit +import comfy.clip_model from . 
import qwen_vl @@ -31,6 +32,7 @@ class Llama2Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Mistral3Small24BConfig: @@ -53,6 +55,7 @@ class Mistral3Small24BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_3BConfig: @@ -75,6 +78,30 @@ class Qwen25_3BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_06BConfig: + vocab_size: int = 151936 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False @dataclass class Qwen3_4BConfig: @@ -97,6 +124,30 @@ class Qwen3_4BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_8BConfig: + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False @dataclass class Ovis25_2BConfig: @@ -119,6 +170,7 @@ class Ovis25_2BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_7BVLI_Config: @@ -141,6 +193,7 @@ class Qwen25_7BVLI_Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma2_2B_Config: @@ -164,6 +217,7 @@ class Gemma2_2B_Config: sliding_attention = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma3_4B_Config: @@ -187,6 +241,33 @@ class Gemma3_4B_Config: sliding_attention = [1024, 1024, 1024, 1024, 1024, False] rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False + +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + lm_head: bool = False + vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + mm_tokens_per_image = 256 class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -286,6 +367,7 @@ class Attention(nn.Module): attention_mask: 
Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): batch_size, seq_length, _ = hidden_states.shape xq = self.q_proj(hidden_states) @@ -303,11 +385,30 @@ class Attention(nn.Module): xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + present_key_value = None + if past_key_value is not None: + index = 0 + num_tokens = xk.shape[2] + if len(past_key_value) > 0: + past_key, past_value, index = past_key_value + if past_key.shape[2] >= (index + num_tokens): + past_key[:, :, index:index + xk.shape[2]] = xk + past_value[:, :, index:index + xv.shape[2]] = xv + xk = past_key[:, :, :index + xk.shape[2]] + xv = past_value[:, :, :index + xv.shape[2]] + present_key_value = (past_key, past_value, index + num_tokens) + else: + xk = torch.cat((past_key[:, :, :index], xk), dim=2) + xv = torch.cat((past_value[:, :, :index], xv), dim=2) + present_key_value = (xk, xv, index + num_tokens) + else: + present_key_value = (xk, xv, index + num_tokens) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) + return self.o_proj(output), present_key_value class MLP(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -338,15 +439,17 @@ class TransformerBlock(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = residual + x @@ -356,7 +459,7 @@ class TransformerBlock(nn.Module): x = self.mlp(x) x = residual + x - return x + return x, present_key_value class TransformerBlockGemma2(nn.Module): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): @@ -381,6 +484,7 @@ class TransformerBlockGemma2(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): if self.transformer_type == 'gemma3': if self.sliding_attention: @@ -398,11 +502,12 @@ class TransformerBlockGemma2(nn.Module): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = self.post_attention_layernorm(x) @@ -415,7 +520,7 @@ class TransformerBlockGemma2(nn.Module): x = self.post_feedforward_layernorm(x) x = residual + x - return x + return x, present_key_value class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): @@ -446,9 +551,10 @@ class Llama2_(nn.Module): else: self.norm = None - # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + if config.lm_head: + self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, 
embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): if embeds is not None: x = embeds else: @@ -457,8 +563,13 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = past_key_values[0][2] + if position_ids is None: - position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, @@ -469,14 +580,16 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask + if seq_len > 1: + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask + optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None @@ -492,16 +605,27 @@ class Llama2_(nn.Module): elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output + next_key_values = [] for i, layer in enumerate(self.layers): if all_intermediate is not None: if only_layers is None or (i in only_layers): all_intermediate.append(x.unsqueeze(1).clone()) - x = layer( + + past_kv = None + if past_key_values is not None: + past_kv = past_key_values[i] if len(past_key_values) > 0 else [] + + x, current_kv = layer( x=x, attention_mask=mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_kv, ) + + if current_kv is not None: + next_key_values.append(current_kv) + if i == intermediate_output: intermediate = x.clone() @@ -518,7 +642,45 @@ class Llama2_(nn.Module): if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) - return x, intermediate + if len(next_key_values) > 0: + return x, intermediate, next_key_values + else: + return x, intermediate + + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + 
self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + class BaseLlama: def get_input_embeddings(self): @@ -558,6 +720,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_06B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Qwen3_4B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -567,6 +738,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_8B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_8BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Ovis25_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -636,3 +816,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_12B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67..26573fb12 100644 --- a/comfy/text_encoders/lt.py +++ 
b/comfy/text_encoders/lt.py @@ -1,7 +1,11 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import torch +import comfy.utils class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +20,133 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + +class Gemma3_12BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): + text = llama_template.format(text) + text_tokens = super().tokenize_with_weights(text, return_word_ids) + embed_count = 0 + for k in text_tokens: + tt = text_tokens[k] + for r in tt: + for i in range(len(r)): + if r[i][0] == 262144: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return text_tokens + +class LTXAVTEModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + self.dtypes.add(dtype) + + self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) + self.dtypes.add(dtype_llama) + + operations = self.gemma3_12b.operations # TODO + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + 
operations=operations, + ) + + def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + self.execution_device = None + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out_device = out.device + if comfy.model_management.should_use_bf16(self.execution_device): + out = out.to(device=self.execution_device, dtype=torch.bfloat16) + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + out = out.float() + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + + return out.to(out_device), pooled + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection), ("video_embeddings_connector.", self.video_embeddings_connector), ("audio_embeddings_connector.", self.audio_embeddings_connector)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + return (missing_all, unexpected_all) + + def memory_estimation_function(self, token_weight_pairs, device=None): + constant = 6.0 + if comfy.model_management.should_use_bf16(device): + constant /= 2.0 + + token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) + num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + return num_tokens * constant * 1024 * 1024 + +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): + class LTXAVTEModel_(LTXAVTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return LTXAVTEModel_ diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 5754424d2..2cc0867c3 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() 
model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return OvisTEModel_ diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index e5e5f18be..51c6e50c7 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return PixArtTEModel_ diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 19adde0b7..ad41bfb1e 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return ZImageTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index e4162d7ac..c1b536833 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -28,8 +28,11 @@ import logging import itertools from torch.nn.functional import interpolate from einops import rearrange -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram import json +import time +import mmap +import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -55,21 +58,70 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in else: logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. 
Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, +} + +def load_safetensors(ckpt): + f = open(ckpt, "rb") + mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + + header_size = struct.unpack(" 0: message = e.args[0] @@ -610,6 +662,14 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", + "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr + "ff.linear_in.bias": "img_mlp.0.bias", + "ff.linear_out.weight": "img_mlp.2.weight", + "ff.linear_out.bias": "img_mlp.2.bias", + "ff_context.linear_in.weight": "txt_mlp.0.weight", + "ff_context.linear_in.bias": "txt_mlp.0.bias", + "ff_context.linear_out.weight": "txt_mlp.2.weight", + "ff_context.linear_out.bias": "txt_mlp.2.bias", "attn.norm_q.weight": "img_attn.norm.query_norm.scale", "attn.norm_k.weight": "img_attn.norm.key_norm.scale", "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", @@ -638,6 +698,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "proj_out.bias": "linear2.bias", "attn.norm_q.weight": "norm.query_norm.scale", "attn.norm_k.weight": "norm.key_norm.scale", + "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 + "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: @@ -928,7 +990,9 @@ def bislerp(samples, width, height): return result.to(orig_dtype) def lanczos(samples, width, height): - images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + #the below API is strict and expects grayscale to be squeezed + samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) + images = [Image.fromarray(np.clip(255. 
* image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) @@ -1097,6 +1161,10 @@ def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function +# Throttle settings for progress bar updates to reduce WebSocket flooding +PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates +PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change + class ProgressBar: def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK @@ -1104,6 +1172,8 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK self.node_id = node_id + self._last_update_time = 0.0 + self._last_sent_value = -1 def update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1112,7 +1182,29 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview, node_id=self.node_id) + current_time = time.perf_counter() + is_first = (self._last_sent_value < 0) + is_final = (value >= self.total) + has_preview = (preview is not None) + + # Always send immediately for previews, first update, or final update + if has_preview or is_first or is_final: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value + return + + # Apply throttling for regular progress updates + if self.total > 0: + percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 + else: + percent_changed = 100 + time_elapsed = current_time - self._last_update_time + + if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value def update(self, value): self.update_absolute(self.current + value) @@ -1198,7 +1290,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): @@ -1267,3 +1359,16 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index b40f920e4..b9fa8d5cf 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -5,6 +5,11 @@ from .lokr import LoKrAdapter from .glora import GLoRAAdapter from .oft import OFTAdapter from .boft import BOFTAdapter +from .bypass import ( + BypassInjectionManager, + BypassForwardHook, + create_bypass_injections_from_patches, +) adapters: list[type[WeightAdapterBase]] = [ @@ -31,4 +36,7 @@ __all__ = [ "WeightAdapterTrainBase", "adapters", "adapter_maps", + "BypassInjectionManager", + 
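# --- Editor's aside (not part of the patch): string_to_seed() above is a plain
# bitwise CRC-32 (reflected polynomial 0xEDB88320, 0xFFFFFFFF init and final
# XOR), so for bytes or plain-ASCII str input it should agree with zlib.crc32.
# A minimal sanity sketch, assuming the function is importable from comfy.utils
# once the hunk above is applied:
import zlib
from comfy.utils import string_to_seed

assert string_to_seed(b"comfy") == zlib.crc32(b"comfy")
assert string_to_seed("comfy") == zlib.crc32(b"comfy")  # str characters are fed through ord()
# --- end editor's aside ---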
"BypassForwardHook", + "create_bypass_injections_from_patches", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 43644b106..bce89a0e2 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -7,12 +7,35 @@ import comfy.model_management class WeightAdapterBase: + """ + Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT/BOFT: g = transform, h = 0 + """ + name: str loaded_keys: set[str] weights: list[torch.Tensor] + # Attributes set by bypass system + multiplier: float = 1.0 + shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel) + @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + ) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -39,18 +62,202 @@ class WeightAdapterBase: ): raise NotImplementedError + # ===== Bypass Mode Methods ===== + # + # IMPORTANT: Bypass mode is designed for quantized models where original weights + # may not be accessible in a usable format. Therefore, h() and bypass_forward() + # do NOT take org_weight as a parameter. All necessary information (out_channels, + # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook. + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT/BOFT), returns zeros. + + Note: + This method does NOT access original model weights. Bypass mode is + designed for quantized models where weights may not be in a usable format. + All shape info comes from module attributes set by BypassForwardHook. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Reference: LyCORIS LoConModule.bypass_forward_diff + """ + # Default: no additive component (for OFT/BOFT) + # Simply return zeros matching base_out shape + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT/BOFT override this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + + Reference: LyCORIS OFTModule applies orthogonal transform here + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Note: + This method does NOT take org_weight/org_bias parameters. Bypass mode + is designed for quantized models where weights may not be accessible. 
+ The original forward function handles weight access internally. + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + + Reference: LyCORIS LoConModule.bypass_forward + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + class WeightAdapterTrainBase(nn.Module): - # We follow the scheme of PR #7032 + """ + Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT: g = transform, h = 0 + + Note: + Unlike WeightAdapterBase, TrainBase classes have simplified weight formats + with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition). + + We follow the scheme of PR #7032 + """ + + # Attributes set by bypass system (BypassForwardHook) + # These are set before h()/g()/bypass_forward() are called + multiplier: float = 1.0 + is_conv: bool = False + conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d + kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups + kernel_size: tuple = () + in_channels: int = None + out_channels: int = None + def __init__(self): super().__init__() def __call__(self, w): """ - w: The original weight tensor to be modified. + Weight modification mode: returns modified weight. + + Args: + w: The original weight tensor to be modified. + + Returns: + Modified weight tensor. """ raise NotImplementedError + # ===== Bypass Mode Methods ===== + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT), returns zeros. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Subclasses should override this method. + """ + raise NotImplementedError( + f"{self.__class__.__name__}.h() not implemented. " + "Subclasses must implement h() for bypass mode." + ) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT overrides this to apply orthogonal transformation. 
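# --- Editor's aside (not part of the patch): the bypass identity the docstrings
# above describe, checked numerically for the simplest case -- a plain LoRA on a
# Linear layer, where g is the identity and h(x) = up(down(x)) * scale. The
# merged-weight path and the bypass path should match:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
W = torch.randn(32, 16)              # base weight of the Linear layer
down = torch.randn(4, 16) * 0.01     # LoRA down projection (rank 4)
up = torch.randn(32, 4) * 0.01       # LoRA up projection
scale = 0.8
x = torch.randn(2, 16)

merged = F.linear(x, W + scale * (up @ down))                        # weight-patch mode
bypass = F.linear(x, W) + scale * F.linear(F.linear(x, down), up)    # g(f(x) + h(x))
assert torch.allclose(merged, bypass, atol=1e-5)
# --- end editor's aside ---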
+ + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + def passive_memory_usage(self): raise NotImplementedError("passive_memory_usage is not implemented") @@ -59,8 +266,12 @@ class WeightAdapterTrainBase(nn.Module): return self.passive_memory_usage() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) +def weight_decompose( + dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function +): + dora_scale = comfy.model_management.cast_to_device( + dora_scale, weight.device, intermediate_dtype + ) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) @@ -106,10 +317,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten the original tensor will be truncated in that dimension. """ if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): - raise ValueError("The new shape must be larger than the original tensor in all dimensions") + raise ValueError( + "The new shape must be larger than the original tensor in all dimensions" + ) if len(new_shape) != len(tensor.shape): - raise ValueError("The new shape must have the same number of dimensions as the original tensor") + raise ValueError( + "The new shape must have the same number of dimensions as the original tensor" + ) # Create a new tensor filled with zeros padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd4..02a8dc130 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase): alpha = v[2] dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) boft_m, block_num, boft_b, *_ = blocks.shape @@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(-1, -2) normed_q = q - if alpha > 0: # alpha in boft/bboft is for constraint + if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm @@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase): r = r.to(weight) inp = org = weight - r_b = boft_b//2 + r_b = boft_b // 2 for i in range(boft_m): bi = r[i] g = 2 k = 2**i * r_b if strength != 1: - bi = bi * strength + (1-strength) * I + bi = bi * strength + (1 - 
strength) * I inp = ( inp.unflatten(0, (-1, g, k)) .transpose(1, 2) @@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase): ) inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) + inp.flatten(0, 1) + .unflatten(0, (-1, k, g)) + .transpose(1, 2) + .flatten(0, 2) ) if rescale is not None: inp = inp * rescale lora_diff = inp - org - lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + lora_diff = comfy.model_management.cast_to_device( + lora_diff, weight.device, intermediate_dtype + ) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrices(self, device, dtype): + """Compute the orthogonal rotation matrices R from BOFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + boft_m, block_num, boft_b, _ = blocks.shape + I = torch.eye(boft_b, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(-1, -2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, boft_m, boft_b + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for BOFT: applies butterfly orthogonal transform. + + BOFT uses multiple stages of butterfly-structured orthogonal transforms. + + Reference: LyCORIS ButterflyOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype) + r_b = boft_b // 2 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(boft_b, device=y.device, dtype=y.dtype) + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # Apply butterfly transform stages + inp = y + for i in range(boft_m): + bi = r[i] # (block_num, boft_b, boft_b) + g = 2 + k = 2**i * r_b + + # Interpolate with identity based on multiplier + if multiplier != 1: + bi = bi * multiplier + (1 - multiplier) * I + + # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + # Apply block-diagonal orthogonal transform + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # Reshape back + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + inp = inp * rescale.transpose(0, -1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) 
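# --- Editor's aside (not part of the patch): the rotation built by
# _get_orthogonal_matrices() above is the Cayley transform of a skew-symmetric
# block, which is orthogonal by construction. A quick check of that property on
# random blocks shaped like BOFT's (boft_m, block_num, boft_b, boft_b):
import torch

torch.manual_seed(0)
boft_b = 4
blocks = torch.randn(2, 3, boft_b, boft_b) * 0.1
I = torch.eye(boft_b)
q = blocks - blocks.transpose(-1, -2)        # skew-symmetric: q^T = -q
r = (I + q) @ torch.linalg.inv(I - q)        # Cayley transform
assert torch.allclose(r @ r.transpose(-1, -2), I.expand_as(r), atol=1e-5)
# --- end editor's aside ---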
+ inp = inp.transpose(1, -1) + + return inp diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py new file mode 100644 index 000000000..d4aaf98ca --- /dev/null +++ b/comfy/weight_adapter/bypass.py @@ -0,0 +1,437 @@ +""" +Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) + +Bypass mode applies adapters during forward pass without modifying base weights: + bypass(f)(x) = g(f(x) + h(x)) + +Where: + - f(x): Original layer forward + - h(x): Additive component from adapter (LoRA path) + - g(y): Output transformation (identity for most adapters) + +This is useful for: + - Training with gradient checkpointing + - Avoiding weight modifications when weights are offloaded + - Supporting multiple adapters with different strengths dynamically +""" + +import logging +from typing import Optional, Union + +import torch +import torch.nn as nn + +from .base import WeightAdapterBase, WeightAdapterTrainBase +from comfy.patcher_extension import PatcherInjection + +# Type alias for adapters that support bypass mode +BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] + + +def get_module_type_info(module: nn.Module) -> dict: + """ + Determine module type and extract conv parameters from module class. + + This is more reliable than checking weight.ndim, especially for quantized layers + where weight shape might be different. + + Returns: + dict with keys: is_conv, conv_dim, stride, padding, dilation, groups + """ + info = { + "is_conv": False, + "conv_dim": 0, + "stride": (1,), + "padding": (0,), + "dilation": (1,), + "groups": 1, + "kernel_size": (1,), + "in_channels": None, + "out_channels": None, + } + + # Determine conv type + if isinstance(module, nn.Conv1d): + info["is_conv"] = True + info["conv_dim"] = 1 + elif isinstance(module, nn.Conv2d): + info["is_conv"] = True + info["conv_dim"] = 2 + elif isinstance(module, nn.Conv3d): + info["is_conv"] = True + info["conv_dim"] = 3 + elif isinstance(module, nn.Linear): + info["is_conv"] = False + info["conv_dim"] = 0 + else: + # Try to infer from class name for custom/quantized layers + class_name = type(module).__name__.lower() + if "conv3d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 3 + elif "conv2d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + elif "conv1d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 1 + elif "conv" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + + # Extract conv parameters if it's a conv layer + if info["is_conv"]: + # Try to get stride, padding, dilation, groups, kernel_size from module + info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) + info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) + info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) + info["groups"] = getattr(module, "groups", 1) + info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) + info["in_channels"] = getattr(module, "in_channels", None) + info["out_channels"] = getattr(module, "out_channels", None) + + # Ensure they're tuples + if isinstance(info["stride"], int): + info["stride"] = (info["stride"],) * info["conv_dim"] + if isinstance(info["padding"], int): + info["padding"] = (info["padding"],) * info["conv_dim"] + if isinstance(info["dilation"], int): + info["dilation"] = (info["dilation"],) * info["conv_dim"] + if isinstance(info["kernel_size"], int): + info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] + + return info + + +class BypassForwardHook: + """ 
+ Hook that wraps a layer's forward to apply adapter in bypass mode. + + Stores the original forward and replaces it with bypass version. + + Supports both: + - WeightAdapterBase: Inference adapters (uses self.weights tuple) + - WeightAdapterTrainBase: Training adapters (nn.Module with parameters) + """ + + def __init__( + self, + module: nn.Module, + adapter: BypassAdapter, + multiplier: float = 1.0, + ): + self.module = module + self.adapter = adapter + self.multiplier = multiplier + self.original_forward = None + + # Determine layer type and conv params from module class (works for quantized layers) + module_info = get_module_type_info(module) + + # Set multiplier and layer type info on adapter for use in h() + adapter.multiplier = multiplier + adapter.is_conv = module_info["is_conv"] + adapter.conv_dim = module_info["conv_dim"] + adapter.kernel_size = module_info["kernel_size"] + adapter.in_channels = module_info["in_channels"] + adapter.out_channels = module_info["out_channels"] + # Store kw_dict for conv operations (like LyCORIS extra_args) + if module_info["is_conv"]: + adapter.kw_dict = { + "stride": module_info["stride"], + "padding": module_info["padding"], + "dilation": module_info["dilation"], + "groups": module_info["groups"], + } + else: + adapter.kw_dict = {} + + def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x)) + + Note: + Bypass mode does NOT access original model weights (org_weight). + This is intentional - bypass mode is designed for quantized models + where weights may not be in a usable format. All necessary shape + information is provided via adapter attributes set during inject(). + """ + # Check if adapter has custom bypass_forward (e.g., GLoRA) + adapter_bypass = getattr(self.adapter, "bypass_forward", None) + if adapter_bypass is not None: + # Check if it's overridden (not the base class default) + # Need to check both base classes since adapter could be either type + adapter_type = type(self.adapter) + is_default_bypass = ( + adapter_type.bypass_forward is WeightAdapterBase.bypass_forward + or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward + ) + if not is_default_bypass: + return adapter_bypass(self.original_forward, x, *args, **kwargs) + + # Default bypass: g(f(x) + h(x, f(x))) + base_out = self.original_forward(x, *args, **kwargs) + h_out = self.adapter.h(x, base_out) + return self.adapter.g(base_out + h_out) + + def inject(self): + """Replace module forward with bypass version.""" + if self.original_forward is not None: + logging.debug( + f"[BypassHook] Already injected for {type(self.module).__name__}" + ) + return # Already injected + + # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward + device = None + dtype = None + if hasattr(self.module, "weight") and self.module.weight is not None: + device = self.module.weight.device + dtype = self.module.weight.dtype + elif hasattr(self.module, "W_q"): # Quantized layers might use different attr + device = self.module.W_q.device + dtype = self.module.W_q.dtype + + if device is not None: + self._move_adapter_weights_to_device(device, dtype) + + self.original_forward = self.module.forward + self.module.forward = self._bypass_forward + logging.debug( + f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" + ) + + def _move_adapter_weights_to_device(self, device, dtype=None): + """Move adapter weights to 
specified device to avoid per-forward transfers. + + Handles both: + - WeightAdapterBase: has self.weights tuple of tensors + - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method + """ + adapter = self.adapter + + # Check if adapter is an nn.Module (WeightAdapterTrainBase) + if isinstance(adapter, nn.Module): + # In training mode we don't touch dtype as trainer will handle it + adapter.to(device=device) + logging.debug( + f"[BypassHook] Moved training adapter (nn.Module) to {device}" + ) + return + + # WeightAdapterBase: handle self.weights tuple + if not hasattr(adapter, "weights") or adapter.weights is None: + return + + weights = adapter.weights + if isinstance(weights, (list, tuple)): + new_weights = [] + for w in weights: + if isinstance(w, torch.Tensor): + if dtype is not None: + new_weights.append(w.to(device=device, dtype=dtype)) + else: + new_weights.append(w.to(device=device)) + else: + new_weights.append(w) + adapter.weights = ( + tuple(new_weights) if isinstance(weights, tuple) else new_weights + ) + elif isinstance(weights, torch.Tensor): + if dtype is not None: + adapter.weights = weights.to(device=device, dtype=dtype) + else: + adapter.weights = weights.to(device=device) + + logging.debug(f"[BypassHook] Moved adapter weights to {device}") + + def eject(self): + """Restore original module forward.""" + if self.original_forward is None: + logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + return # Not injected + + self.module.forward = self.original_forward + self.original_forward = None + logging.debug( + f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" + ) + + +class BypassInjectionManager: + """ + Manages bypass mode injection for a collection of adapters. + + Creates PatcherInjection objects that can be used with ModelPatcher. + + Supports both inference adapters (WeightAdapterBase) and training adapters + (WeightAdapterTrainBase). + + Usage: + manager = BypassInjectionManager() + manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8) + manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8) + + injections = manager.create_injections(model) + model_patcher.set_injections("bypass_lora", injections) + """ + + def __init__(self): + self.adapters: dict[str, tuple[BypassAdapter, float]] = {} + self.hooks: list[BypassForwardHook] = [] + + def add_adapter( + self, + key: str, + adapter: BypassAdapter, + strength: float = 1.0, + ): + """ + Add an adapter for a specific weight key. + + Args: + key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight") + adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.) 
+ strength: Multiplier for adapter effect + """ + # Remove .weight suffix if present for module lookup + module_key = key + if module_key.endswith(".weight"): + module_key = module_key[:-7] + logging.debug( + f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" + ) + + self.adapters[module_key] = (adapter, strength) + logging.debug( + f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" + ) + + def clear_adapters(self): + """Remove all adapters.""" + self.adapters.clear() + + def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]: + """Get a submodule by dot-separated key.""" + parts = key.split(".") + module = model + try: + for i, part in enumerate(parts): + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + logging.debug( + f"[BypassManager] Found module for key {key}: {type(module).__name__}" + ) + return module + except (AttributeError, IndexError, KeyError) as e: + logging.error(f"[BypassManager] Failed to find module for key {key}: {e}") + logging.error( + f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}" + ) + return None + + def create_injections(self, model: nn.Module) -> list[PatcherInjection]: + """ + Create PatcherInjection objects for all registered adapters. + + Args: + model: The model to inject into (e.g., model_patcher.model) + + Returns: + List of PatcherInjection objects to use with model_patcher.set_injections() + """ + self.hooks.clear() + + logging.debug( + f"[BypassManager] create_injections called with {len(self.adapters)} adapters" + ) + logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + + for key, (adapter, strength) in self.adapters.items(): + logging.debug(f"[BypassManager] Looking for module: {key}") + module = self._get_module_by_key(model, key) + + if module is None: + logging.warning(f"[BypassManager] Module not found for key {key}") + continue + + if not hasattr(module, "weight"): + logging.warning( + f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})" + ) + continue + + logging.debug( + f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" + ) + hook = BypassForwardHook(module, adapter, multiplier=strength) + self.hooks.append(hook) + + logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + + # Create single injection that manages all hooks + def inject_all(model_patcher): + logging.debug( + f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.inject() + logging.debug( + f"[BypassManager] Injected hook for {type(hook.module).__name__}" + ) + + def eject_all(model_patcher): + logging.debug( + f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.eject() + + return [PatcherInjection(inject=inject_all, eject=eject_all)] + + def get_hook_count(self) -> int: + """Return number of hooks that will be/are injected.""" + return len(self.hooks) + + +def create_bypass_injections_from_patches( + model: nn.Module, + patches: dict, + strength: float = 1.0, +) -> list[PatcherInjection]: + """ + Convenience function to create bypass injections from a patches dict. + + This is useful when you have patches in the format used by model_patcher.add_patches() + and want to apply them in bypass mode instead. 
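# --- Editor's aside (not part of the patch): a minimal end-to-end wiring sketch
# for BypassInjectionManager on a toy model. It assumes the
# LoRAAdapter(loaded_keys, weights) constructor used by its load() classmethod,
# the (up, down, alpha, mid, dora_scale, reshape) weights layout from lora.py,
# and that PatcherInjection exposes the stored callables as .inject/.eject; in
# ComfyUI proper the injections would be handed to model_patcher.set_injections()
# instead of being called by hand.
import torch
import torch.nn as nn
from comfy.weight_adapter import BypassInjectionManager
from comfy.weight_adapter.lora import LoRAAdapter

model = nn.Sequential(nn.Linear(16, 32))
adapter = LoRAAdapter(
    set(),
    (torch.randn(32, 4) * 0.01, torch.randn(4, 16), 4.0, None, None, None),
)

manager = BypassInjectionManager()
manager.add_adapter("0.weight", adapter, strength=1.0)   # ".weight" suffix is stripped
injections = manager.create_injections(model)

for inj in injections:
    inj.inject(None)                  # hooks the Linear's forward
out = model(torch.randn(2, 16))       # now computes g(f(x) + h(x))
for inj in injections:
    inj.eject(None)                   # restores the original forward
# --- end editor's aside ---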
+ + Args: + model: The model to inject into + patches: Dict mapping weight keys to adapter data + strength: Global strength multiplier + + Returns: + List of PatcherInjection objects + """ + manager = BypassInjectionManager() + + for key, patch_list in patches.items(): + if not patch_list: + continue + + # patches format: list of (strength_patch, patch_data, strength_model, offset, function) + for patch in patch_list: + patch_strength, patch_data, strength_model, offset, function = patch + + # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple + if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)): + adapter = patch_data + else: + # Skip non-adapter patches + continue + + combined_strength = strength * patch_strength + manager.add_adapter(key, adapter, strength=combined_strength) + + return manager.create_injections(model) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5..d6b97a23b 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,8 @@ import logging -from typing import Optional +from typing import Callable, Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, weight_decompose @@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + weights = ( + lora[a1_name], + lora[a2_name], + lora[b1_name], + lora[b2_name], + alpha, + dora_scale, + ) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase): old_glora = True if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + if ( + old_glora + and v[1].shape[0] == weight.shape[0] + and weight.shape[0] == weight.shape[1] + ): pass else: old_glora = False rank = v[1].shape[0] - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + a1 = comfy.model_management.cast_to_device( + v[0].flatten(start_dim=1), weight.device, intermediate_dtype + ) + a2 = comfy.model_management.cast_to_device( + v[1].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b1 = comfy.model_management.cast_to_device( + v[2].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b2 = comfy.model_management.cast_to_device( + v[3].flatten(start_dim=1), weight.device, intermediate_dtype + ) if v[4] is not None: alpha = v[4] / rank @@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase): try: if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + lora_diff = ( + torch.mm(b2, b1) + + torch.mm( + torch.mm( + weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2 + ), + a1, + ) + ).reshape( + weight.shape + ) # old lycoris glora else: if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i 
..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.einsum( + "o i ..., i j -> o j ...", + torch.einsum( + "o i ..., i j -> o j ...", + weight.to(dtype=intermediate_dtype), + a1, + ), + a2, + ).reshape(weight.shape) else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.mm( + torch.mm(weight.to(dtype=intermediate_dtype), a1), a2 + ).reshape(weight.shape) lora_diff += torch.mm(b1, b2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _compute_paths(self, x: torch.Tensor): + """ + Compute A path and B path outputs for GLoRA bypass. + + GLoRA: f(x) = Wx + WAx + Bx + - A path: a1(a2(x)) - modifies input to base forward + - B path: b1(b2(x)) - additive component + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Returns: (a_out, b_out) + """ + v = self.weights + # v = (a1, a2, b1, b2, alpha, dora_scale) + a1 = v[0] + a2 = v[1] + b1 = v[2] + b2 = v[3] + alpha = v[4] + + dtype = x.dtype + + # Cast dtype (weights should already be on correct device from inject()) + a1 = a1.to(dtype=dtype) + a2 = a2.to(dtype=dtype) + b1 = b1.to(dtype=dtype) + b2 = b2.to(dtype=dtype) + + # Determine rank and scale + # Check for old vs new glora format + old_glora = False + if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]: + rank = a1.shape[0] + old_glora = True + + if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]: + if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]: + pass + else: + old_glora = False + rank = a2.shape[0] + + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + scale = scale * multiplier + + # Use module info from bypass injection, not input tensor shape + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Get module's stride/padding for spatial dimension handling + module_stride = kw_dict.get("stride", (1,) * conv_dim) + module_padding = kw_dict.get("padding", (0,) * conv_dim) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Ensure weights are in conv shape + # a1, a2, b1 are always 1x1 kernels + if a1.ndim == 2: + a1 = a1.view(*a1.shape, *([1] * conv_dim)) + if a2.ndim == 2: + a2 = a2.view(*a2.shape, *([1] * conv_dim)) + if b1.ndim == 2: + b1 = b1.view(*b1.shape, *([1] * conv_dim)) + # b2 has actual kernel_size (like LoRA down) + if b2.ndim == 2: + if in_channels is not None: + b2 = b2.view(b2.shape[0], in_channels, *kernel_size) + else: + b2 = b2.view(*b2.shape, *([1] * conv_dim)) + + # A path: a2(x) -> a1(...) 
- 1x1 convs, no stride/padding needed, a_out is added to x + a2_out = conv_fn(x, a2) + a_out = conv_fn(a2_out, a1) * scale + + # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1 + b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding) + b_out = conv_fn(b2_out, b1) * scale + else: + # Linear case + if old_glora: + # Old format: a1 @ a2 @ x, b2 @ b1 + a_out = F.linear(F.linear(x, a2), a1) * scale + b_out = F.linear(F.linear(x, b1), b2) * scale + else: + # New format: x @ a1 @ a2, b1 @ b2 + a_out = F.linear(F.linear(x, a1), a2) * scale + b_out = F.linear(F.linear(x, b2), b1) * scale + + return a_out, b_out + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + GLoRA bypass forward: f(x + a(x)) + b(x) + + Unlike standard adapters, GLoRA modifies the input to the base forward + AND adds the B path output. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Reference: LyCORIS GLoRAModule._bypass_forward + """ + a_out, b_out = self._compute_paths(x) + + # Call base forward with modified input + base_out = org_forward(x + a_out, *args, **kwargs) + + # Add B path + return base_out + b_out + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + For GLoRA, h() returns the B path output. + + Note: + GLoRA's full bypass requires overriding bypass_forward() since + it also modifies the input to org_forward. This h() is provided for + compatibility but bypass_forward() should be used for correct behavior. + + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + _, b_out = self._compute_paths(x) + return b_out diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d403..8007b7b44 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,11 +1,22 @@ import logging +from functools import cache from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose +@cache +def _warn_loha_bypass_inefficient(): + """One-time warning about LoHa bypass inefficiency.""" + logging.warning( + "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. " + "Consider using LoRA or LoKr for training with bypass mode." 
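# --- Editor's aside (not part of the patch): what the inefficiency warning above
# refers to. In bypass mode the LoHa delta is a full Hadamard-product weight that
# must be rebuilt on every forward; for a Linear layer the h() added further
# below boils down to the following (multiplier omitted):
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_dim, in_dim, rank, alpha = 8, 6, 3, 3.0
w1a, w1b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
w2a, w2b = torch.randn(out_dim, rank), torch.randn(rank, in_dim)
x = torch.randn(2, in_dim)

diff_w = (w1a @ w1b) * (w2a @ w2b) * (alpha / rank)   # recomputed on each call
h_out = F.linear(x, diff_w)                           # delta added to the base output
# --- end editor's aside ---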
+ ) + + class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): @@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase): scale = self.alpha / self.rank if self.use_tucker: - diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeightTucker.apply( + self.hada_t1, + self.hada_w1_a, + self.hada_w1_b, + self.hada_t2, + self.hada_w2_a, + self.hada_w2_b, + scale, + ) else: - diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeight.apply( + self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale + ) # Add the scaled difference to the original weight weight = w.to(diff_weight) + diff_weight.reshape(w.shape) @@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase): mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) - return LohaDiff( - (mat1, mat2, alpha, mat3, mat4, None, None, None) - ) + return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None)) def to_train(self): return LohaDiff(self.weights) @@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + weights = ( + lora[hada_w1_a_name], + lora[hada_w1_b_name], + alpha, + lora[hada_w2_a_name], + lora[hada_w2_b_name], + hada_t1, + hada_t2, + dora_scale, + ) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + m1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t1, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + ) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + m2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + ) else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, 
intermediate_dtype)) + m1 = torch.mm( + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + ) + m2 = torch.mm( + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + ) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoHa: h(x) = diff_weight @ x + + WARNING: Inefficient - computes full Hadamard product each forward. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/loha.py bypass_forward_diff + """ + _warn_loha_bypass_inefficient() + + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora + w1a = v[0] + w1b = v[1] + alpha = v[2] + w2a = v[3] + w2b = v[4] + t1 = v[5] + t2 = v[6] + + # Compute scale + rank = w1b.shape[0] + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Cast dtype + w1a = w1a.to(dtype=x.dtype) + w1b = w1b.to(dtype=x.dtype) + w2a = w2a.to(dtype=x.dtype) + w2b = w2b.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Compute diff weight using Hadamard product + if t1 is not None and t2 is not None: + t1 = t1.to(dtype=x.dtype) + t2 = t2.to(dtype=x.dtype) + m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + diff_weight = (m1 * m2) * scale + else: + m1 = w1a @ w1b + m2 = w2a @ w2b + diff_weight = (m1 * m2) * scale + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D diff_weight to conv format using kernel_size + # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size] + if diff_weight.dim() == 2: + if in_channels is not None: + diff_weight = diff_weight.view( + diff_weight.shape[0], in_channels, *kernel_size + ) + else: + diff_weight = diff_weight.view( + *diff_weight.shape, *([1] * conv_dim) + ) + else: + op = F.linear + kw_dict = {} + + return op(x, diff_weight, **kw_dict) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7..b83750012 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F 
import comfy.model_management from .base import ( WeightAdapterBase, @@ -14,7 +15,17 @@ from .base import ( class LokrDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) = weights self.use_tucker = False if lokr_w1_a is not None: _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] @@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase): if self.w2_rebuild: if self.use_tucker: w2 = torch.einsum( - 'i j k l, j r, i p -> p r k l', + "i j k l, j r, i p -> p r k l", self.lokr_t2, self.lokr_w2_b, - self.lokr_w2_a + self.lokr_w2_a, ) else: w2 = self.lokr_w2_a @ self.lokr_w2_b @@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase): return self.lokr_w2 def __call__(self, w): - diff = torch.kron(self.w1, self.w2) + w1 = self.w1 + w2 = self.w2 + # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron) + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + diff = torch.kron(w1, w2) return w + diff.reshape(w.shape).to(w) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr training: efficient Kronecker product. + + Uses w1/w2 properties which handle both direct and decomposed cases. + For create_train (direct w1/w2), no alpha scaling in properties. + For to_train (decomposed), alpha/rank scaling is in properties. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Get w1, w2 from properties (handles rebuild vs direct) + w1 = self.w1 + w2 = self.w2 + + # Multiplier from bypass injection + multiplier = getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Efficient Kronecker application without materializing full weight + # kron(w1, w2) @ x can be computed as nested operations + # w1: [out_l, in_m], w2: [out_k, in_n, *k_size] + # Full weight would be [out_l*out_k, in_m*in_n, *k_size] + + uq = w1.size(1) # in_m - inner grouping dimension + + if is_conv: + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + B, C_in, *spatial = x.shape + # Reshape input for grouped application: [B * uq, C_in // uq, *spatial] + h_in_group = x.reshape(B * uq, -1, *spatial) + + # Ensure w2 has conv dims + if w2.dim() == 2: + w2 = w2.view(*w2.shape, *([1] * conv_dim)) + + # Apply w2 path with stride/padding + hb = conv_fn(h_in_group, w2, **kw_dict) + + # Reshape for cross-group operation + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross = hb.transpose(1, -1) + + # Apply w1 (always 2D, applied as linear on channel dim) + hc = F.linear(h_cross, w1) + hc = hc.transpose(1, -1) + + # Reshape to output + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + # Linear case + # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n] + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k] + hb = F.linear(h_in_group, w2) + + # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq] + h_cross = hb.transpose(-1, -2) + + # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l] + hc = F.linear(h_cross, w1) + + # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k] + hc = hc.transpose(-1, -2) + out = 
hc.reshape(*hc.shape[:-2], -1) + + return out * multiplier + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase): @classmethod def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] - in_dim = weight.shape[1:].numel() - out1, out2 = factorization(out_dim, rank) - in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + in_dim = weight.shape[1] # Just in_channels, not flattened with kernel + k_size = weight.shape[2:] if weight.dim() > 2 else () + + out_l, out_k = factorization(out_dim, rank) + in_m, in_n = factorization(in_dim, rank) + + # w1: [out_l, in_m] + mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32) + # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear + mat2 = torch.empty( + out_k, in_n, *k_size, device=weight.device, dtype=torch.float32 + ) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) - return LokrDiff( - (mat1, mat2, alpha, None, None, None, None, None, None) - ) + return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None)) def to_train(self): return LokrDiff(self.weights) @@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase): lokr_t2 = lora[lokr_t2_name] loaded_keys.add(lokr_t2_name) - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + if ( + (lokr_w1 is not None) + or (lokr_w2 is not None) + or (lokr_w1_a is not None) + or (lokr_w2_a is not None) + ): + weights = ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) return cls(loaded_keys, weights) else: return None @@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + w1 = torch.mm( + comfy.model_management.cast_to_device( + w1_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1_b, weight.device, intermediate_dtype + ), + ) else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + w1 = comfy.model_management.cast_to_device( + w1, weight.device, intermediate_dtype + ) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + w2 = torch.mm( + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + ) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + w2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, 
weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + ) else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + w2 = comfy.model_management.cast_to_device( + w2, weight.device, intermediate_dtype + ) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr: efficient Kronecker product application. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/lokr.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora + w1 = v[0] + w2 = v[1] + alpha = v[2] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t2 is not None + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) if is_conv else {} + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + else: + op = F.linear + + # Determine rank and scale + rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Build c (w1) + if use_w1: + c = w1.to(dtype=x.dtype) + else: + c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype) + uq = c.size(1) + + # Build w2 components + if use_w2: + ba = w2.to(dtype=x.dtype) + else: + a = w2_b.to(dtype=x.dtype) + b = w2_a.to(dtype=x.dtype) + if is_conv: + if tucker: + # Tucker: a, b get 1s appended (kernel is in t2) + if a.dim() == 2: + a = a.view(*a.shape, *([1] * conv_dim)) + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + else: + # Non-tucker conv: b may need 1s appended + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + + # Reshape input by uq groups + if is_conv: + B, _, *rest = x.shape + h_in_group = x.reshape(B * uq, -1, *rest) + else: + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2 path + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + t = t2.to(dtype=x.dtype) + if t.dim() == 2: + t = t.view(*t.shape, *([1] * conv_dim)) + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a) + hb = op(ha, b) + + # Reshape and apply c (w1) + if is_conv: + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) 
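# --- Editor's aside (not part of the patch): the grouped-reshape trick used in
# this LoKr bypass path avoids materializing kron(w1, w2). For the linear case
# (scale omitted) it is equivalent to applying the full Kronecker weight:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
out_l, out_k, in_m, in_n = 3, 4, 2, 5
w1 = torch.randn(out_l, in_m)
w2 = torch.randn(out_k, in_n)
x = torch.randn(7, in_m * in_n)

full = F.linear(x, torch.kron(w1, w2))              # materialized Kronecker weight

xg = x.reshape(*x.shape[:-1], in_m, -1)             # split the input into in_m groups
hb = F.linear(xg, w2)                               # apply w2 inside each group
hc = F.linear(hb.transpose(-1, -2), w1)             # apply w1 across groups
grouped = hc.transpose(-1, -2).reshape(*x.shape[:-1], -1)

assert torch.allclose(full, grouped, atol=1e-4)
# --- end editor's aside ---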
+ else: + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + + if is_conv: + hc = hc.transpose(1, -1) + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * scale diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b..bc4260a8f 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase): rank, in_dim = mat2.shape[0], mat2.shape[1] if mid is not None: convdim = mid.ndim - 2 - layer = ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d - )[convdim] + layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim] else: layer = torch.nn.Linear self.lora_up = layer(rank, out_dim, bias=False) @@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase): weight = w + scale * diff.reshape(w.shape) return weight.to(org_dtype) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA training: h(x) = up(down(x)) * scale + + Simple implementation using the nn.Module weights directly. + No mid/dora/reshape branches (create_train doesn't create them). + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Compute scale = alpha / rank * multiplier + scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Get weights (keep in original dtype for numerical stability) + down_weight = self.lora_down.weight + up_weight = self.lora_up.weight + + if is_conv: + # Conv path: use functional conv + # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Reshape 2D weights to conv format if needed + # down: [rank, in_features] -> [rank, in_channels, *kernel_size] + # up: [out_features, rank] -> [out_features, rank, 1, 1, ...] 
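+            # Illustrative shapes (assuming a 3x3 Conv2d adapter):
+            #   down.weight [rank, in_ch * 9] -> view -> [rank, in_ch, 3, 3]
+            #   up.weight   [out_ch, rank]    -> view -> [out_ch, rank, 1, 1]
+            # so the down conv inherits the module's stride/padding via kw_dict,
+            # while the 1x1 up conv only mixes channels.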
+ if down_weight.dim() == 2: + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + if in_channels is not None: + down_weight = down_weight.view( + down_weight.shape[0], in_channels, *kernel_size + ) + else: + # Fallback: assume 1x1 kernel + down_weight = down_weight.view( + *down_weight.shape, *([1] * conv_dim) + ) + if up_weight.dim() == 2: + # up always uses 1x1 kernel + up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim)) + + # down conv uses stride/padding from module, up is 1x1 + hidden = conv_fn(x, down_weight, **kw_dict) + + # mid layer if exists (tucker decomposition) + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + if mid_weight.dim() == 2: + mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim)) + hidden = conv_fn(hidden, mid_weight) + + # up conv is always 1x1 (no stride/padding) + out = conv_fn(hidden, up_weight) + else: + # Linear path: simple matmul chain + hidden = F.linear(x, down_weight) + + # mid layer if exists + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + hidden = F.linear(hidden, mid_weight) + + out = F.linear(hidden, up_weight) + + return out * scale + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase): mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) - return LoraDiff( - (mat1, mat2, alpha, None, None, None) - ) + return LoraDiff((mat1, mat2, alpha, None, None, None)) def to_train(self): return LoraDiff(self.weights) @@ -210,3 +277,85 @@ class LoRAAdapter(WeightAdapterBase): except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA: h(x) = up(down(x)) * scale + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/locon.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape + up = v[0] + down = v[1] + alpha = v[2] + mid = v[3] + + # Compute scale = alpha / rank + rank = down.shape[0] + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + scale = scale * getattr(self, "multiplier", 1.0) + + # Cast dtype + up = up.to(dtype=x.dtype) + down = down.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + op = FUNC_LIST[ + conv_dim + 2 + ] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D weights to conv format using kernel_size + # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size] + # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] 
(1x1 kernel) + if down.dim() == 2: + # down.shape[1] = in_channels * prod(kernel_size) + if in_channels is not None: + down = down.view(down.shape[0], in_channels, *kernel_size) + else: + # Fallback: assume 1x1 kernel if in_channels unknown + down = down.view(*down.shape, *([1] * conv_dim)) + if up.dim() == 2: + # up always uses 1x1 kernel + up = up.view(*up.shape, *([1] * conv_dim)) + if mid is not None: + mid = mid.to(dtype=x.dtype) + if mid.dim() == 2: + mid = mid.view(*mid.shape, *([1] * conv_dim)) + else: + op = F.linear + kw_dict = {} # linear doesn't take stride/padding + + # Simple chain: down -> mid (if tucker) -> up + if mid is not None: + if not is_conv: + mid = mid.to(dtype=x.dtype) + hidden = op(x, down) + hidden = op(hidden, mid, **kw_dict) + out = op(hidden, up) + else: + hidden = op(x, down, **kw_dict) + out = op(hidden, up) + + return out * scale diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab9635..bc83cf8e8 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,13 +3,18 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) class OFTDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - # Unpack weights tuple from LoHaAdapter + # Unpack weights tuple from OFTAdapter blocks, rescale, alpha, _ = weights # Create trainable parameters @@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase): weight = self.rescale * weight return weight.to(org_dtype) + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + blocks = self.oft_blocks.to(device=device, dtype=dtype) + I = torch.eye(self.block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if set + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r.to(dtype) + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + OFT has no additive component - returns zeros matching base_out shape. + + OFT only transforms the output via g(), it doesn't add to it. + """ + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation. + + OFT transforms output channels using block-diagonal orthogonal matrices. + """ + r = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(self.block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) 
-> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.reshape(*batch_shape, out_features) + + # Apply rescale if present + if self.rescaled: + rescale = self.rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + out = out.transpose(1, -1) + + return out + def passive_memory_usage(self): """Calculates memory usage of the trainable parameters.""" return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) - return OFTDiff( - (block, None, alpha, None) + block = torch.zeros( + block_num, block_size, block_size, device=weight.device, dtype=torch.float32 ) + return OFTDiff((block, None, alpha, None)) def to_train(self): return OFTDiff(self.weights) @@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase): alpha = 0 dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) block_num, block_size, *_ = blocks.shape @@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(1, 2) normed_q = q - if alpha > 0: # alpha in oft/boft is for constraint + if alpha > 0: # alpha in oft/boft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() r = r.to(weight) + # Create I in weight's dtype for the einsum + I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype) _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... 
-> k m ...", - (r * strength) - strength * I, + (r * strength) - strength * I_w, weight.view(block_num, block_size, *shape), ).view(-1, *shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + block_num, block_size, _ = blocks.shape + I = torch.eye(block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, block_num, block_size + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation to output. + + OFT transforms the output channels using block-diagonal orthogonal matrices. + + Reference: LyCORIS DiagOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.view(*batch_shape, block_num, block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.view(*batch_shape, out_features) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) 
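+                # (this undoes the channel-last view used for the per-block rotation;
+                #  e.g. an out_features of 640 factorized as block_size=16, block_num=40
+                #  means each output vector was rotated as 40 independent 16-dim blocks)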
+                out = out.transpose(1, -1)
+
+        return out
diff --git a/comfy/windows.py b/comfy/windows.py
new file mode 100644
index 000000000..213dc481d
--- /dev/null
+++ b/comfy/windows.py
@@ -0,0 +1,52 @@
+import ctypes
+import logging
+import psutil
+from ctypes import wintypes
+
+import comfy_aimdo.control
+
+psapi = ctypes.WinDLL("psapi")
+kernel32 = ctypes.WinDLL("kernel32")
+
+class PERFORMANCE_INFORMATION(ctypes.Structure):
+    _fields_ = [
+        ("cb", wintypes.DWORD),
+        ("CommitTotal", ctypes.c_size_t),
+        ("CommitLimit", ctypes.c_size_t),
+        ("CommitPeak", ctypes.c_size_t),
+        ("PhysicalTotal", ctypes.c_size_t),
+        ("PhysicalAvailable", ctypes.c_size_t),
+        ("SystemCache", ctypes.c_size_t),
+        ("KernelTotal", ctypes.c_size_t),
+        ("KernelPaged", ctypes.c_size_t),
+        ("KernelNonpaged", ctypes.c_size_t),
+        ("PageSize", ctypes.c_size_t),
+        ("HandleCount", wintypes.DWORD),
+        ("ProcessCount", wintypes.DWORD),
+        ("ThreadCount", wintypes.DWORD),
+    ]
+
+def get_free_ram():
+    #Windows is way too conservative and counts recently used but uncommitted model RAM
+    #as "in-use". So, calculate free RAM for the sake of general use as the greater of:
+    #
+    #1: What psutil says
+    #2: Total Memory - (Committed Memory - VRAM in use)
+    #
+    #We have to subtract VRAM in use from the committed memory as WDDM creates a naked
+    #commit charge for all VRAM used just in case it wants to page it all out. This just
+    #isn't realistic so "overcommit" in our calculations by just subtracting it off.
+
+    pi = PERFORMANCE_INFORMATION()
+    pi.cb = ctypes.sizeof(pi)
+
+    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
+        logging.warning("WARNING: Failed to query Windows performance info. RAM usage may be suboptimal")
+        return psutil.virtual_memory().available
+
+    committed = pi.CommitTotal * pi.PageSize
+    total = pi.PhysicalTotal * pi.PageSize
+
+    return max(psutil.virtual_memory().available,
+               total - (committed - comfy_aimdo.control.get_total_vram_usage()))
+
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index ea35c6062..1405d0b81 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -374,7 +374,7 @@ class VideoFromComponents(VideoInput):
         if audio_stream and self.__components.audio:
             waveform = self.__components.audio['waveform']
             waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-            frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+            frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
             frame.sample_rate = audio_sample_rate
             frame.pts = 0
             output.mux(audio_stream.encode(frame))
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 764fa8b2b..eeea9781a 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -153,7 +153,7 @@ class Input(_IO_V3):
    '''
    Base class for a V3 Input.
''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -162,6 +162,7 @@ class Input(_IO_V3): self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} self.rawLink = raw_link + self.advanced = advanced def as_dict(self): return prune_dict({ @@ -170,6 +171,7 @@ class Input(_IO_V3): "tooltip": self.tooltip, "lazy": self.lazy, "rawLink": self.rawLink, + "advanced": self.advanced, }) | prune_dict(self.extra_dict) def get_io_type(self): @@ -184,8 +186,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -242,8 +244,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.label_on = label_on self.label_off = label_off self.default: bool @@ -262,8 +264,8 @@ class Int(ComfyTypeIO): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -288,8 +290,8 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, 
extra_dict, raw_link) + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -314,8 +316,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -350,12 +352,13 @@ class Combo(ComfyTypeIO): socketless: bool=None, extra_dict=None, raw_link: bool=None, + advanced: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,8 +390,8 @@ class MultiCombo(ComfyTypeI): class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link) + socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -421,9 +424,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) @comfytype(io_type="MASK") @@ -751,7 +754,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -776,7 +779,7 @@ class 
MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. ''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -789,7 +792,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self._io_types = types @property @@ -843,8 +846,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.template = template def as_dict(self): @@ -997,20 +1000,38 @@ class Autogrow(ComfyTypeI): names = [f"{prefix}{i}" for i in range(max)] # need to create a new input based on the contents of input template_input = None - for _, dict_input in input.items(): - # for now, get just the first value from dict_input + template_required = True + for _input_type, dict_input in input.items(): + # for now, get just the first value from dict_input; if not required, min can be ignored + if len(dict_input) == 0: + continue template_input = list(dict_input.values())[0] + template_required = _input_type == "required" + break + if template_input is None: + raise Exception("template_input could not be determined from required or optional; this should never happen.") new_dict = {} + new_dict_added_to = False + # first, add possible inputs into out_dict for i, name in enumerate(names): expected_id = finalize_prefix(curr_prefix, name) + # required + if i < min and template_required: + out_dict["required"][expected_id] = template_input + type_dict = new_dict.setdefault("required", {}) + # optional + else: + out_dict["optional"][expected_id] = template_input + type_dict = new_dict.setdefault("optional", {}) if expected_id in live_inputs: - # required - if i < min: - type_dict = new_dict.setdefault("required", {}) - # optional - else: - type_dict = new_dict.setdefault("optional", {}) + # NOTE: prefix gets added in parse_class_inputs type_dict[name] = template_input + new_dict_added_to = True + # account for the edge case that all inputs are optional and no values are received + if not new_dict_added_to: + finalized_prefix = finalize_prefix(curr_prefix) + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT 
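+                # (build_nested_inputs later turns this EMPTY_DICT marker into a
+                # literal empty dict, so the node still receives a dict for a
+                # fully-optional group that was given no values)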
parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") @@ -1113,6 +1134,32 @@ class DynamicSlot(ComfyTypeI): out_dict[input_type][finalized_id] = value out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) +@comfytype(io_type="IMAGECOMPARE") +class ImageCompare(ComfyTypeI): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced) + + def as_dict(self): + return super().as_dict() + + +@comfytype(io_type="COLOR") +class Color(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, advanced: bool=None, default: str="#ffffff"): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + self.default: str + + def as_dict(self): + return super().as_dict() + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -1136,6 +1183,8 @@ class V3Data(TypedDict): 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + dynamic_paths_default_value: dict[str, Any] + 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' create_dynamic_tuple: bool 'When True, the value of the dynamic input will be in the format (value, path_key).' 
@@ -1199,6 +1248,7 @@ class Hidden(str, Enum): class NodeInfoV1: input: dict=None input_order: dict[str, list[str]]=None + is_input_list: bool=None output: list[str]=None output_is_list: list[bool]=None output_name: list[str]=None @@ -1212,21 +1262,74 @@ class NodeInfoV1: output_node: bool=None deprecated: bool=None experimental: bool=None + dev_only: bool=None api_node: bool=None + price_badge: dict | None = None + search_aliases: list[str]=None + @dataclass -class NodeInfoV3: - input: dict=None - output: dict=None - hidden: list[str]=None - name: str=None - display_name: str=None - description: str=None - category: str=None - output_node: bool=None - deprecated: bool=None - experimental: bool=None - api_node: bool=None +class PriceBadgeDepends: + widgets: list[str] = field(default_factory=list) + inputs: list[str] = field(default_factory=list) + input_groups: list[str] = field(default_factory=list) + + def validate(self) -> None: + if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets): + raise ValueError("PriceBadgeDepends.widgets must be a list[str].") + if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs): + raise ValueError("PriceBadgeDepends.inputs must be a list[str].") + if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups): + raise ValueError("PriceBadgeDepends.input_groups must be a list[str].") + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + # Build lookup: widget_id -> io_type + input_types: dict[str, str] = {} + for inp in schema_inputs: + all_inputs = inp.get_all() + input_types[inp.id] = inp.get_io_type() # First input is always the parent itself + for nested_inp in all_inputs[1:]: + # For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID + # to match frontend naming convention (e.g., "should_texture.enable_pbr") + prefixed_id = f"{inp.id}.{nested_inp.id}" + input_types[prefixed_id] = nested_inp.get_io_type() + + # Enrich widgets with type information, raising error for unknown widgets + widgets_data: list[dict[str, str]] = [] + for w in self.widgets: + if w not in input_types: + raise ValueError( + f"PriceBadge depends_on.widgets references unknown widget '{w}'. " + f"Available widgets: {list(input_types.keys())}" + ) + widgets_data.append({"name": w, "type": input_types[w]}) + + return { + "widgets": widgets_data, + "inputs": self.inputs, + "input_groups": self.input_groups, + } + + +@dataclass +class PriceBadge: + expr: str + depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends) + engine: str = field(default="jsonata") + + def validate(self) -> None: + if self.engine != "jsonata": + raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.") + if not isinstance(self.expr, str) or not self.expr.strip(): + raise ValueError("PriceBadge.expr must be a non-empty string.") + self.depends_on.validate() + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + return { + "engine": self.engine, + "depends_on": self.depends_on.as_dict(schema_inputs), + "expr": self.expr, + } @dataclass @@ -1244,6 +1347,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. 
Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. @@ -1270,12 +1375,18 @@ class Schema: """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" is_experimental: bool=False """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_dev_only: bool=False + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + price_badge: PriceBadge | None = None + """Optional client-evaluated pricing badge declaration for this node.""" not_idempotent: bool=False """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + accept_all_inputs: bool=False + """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" def validate(self): '''Validate the schema: @@ -1302,6 +1413,8 @@ class Schema: input.validate() for output in self.outputs: output.validate() + if self.price_badge is not None: + self.price_badge.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1362,6 +1475,7 @@ class Schema: info = NodeInfoV1( input=input, input_order={key: list(value.keys()) for (key, value) in input.items()}, + is_input_list=self.is_input_list, output=output, output_is_list=output_is_list, output_name=output_name, @@ -1374,40 +1488,11 @@ class Schema: output_node=self.is_output_node, deprecated=self.is_deprecated, experimental=self.is_experimental, + dev_only=self.is_dev_only, api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") - ) - return info - - - def get_v3_info(self, cls) -> NodeInfoV3: - input_dict = {} - output_dict = {} - hidden_list = [] - # TODO: make sure dynamic types will be handled correctly - if self.inputs: - for input in self.inputs: - add_to_dict_v3(input, input_dict) - if self.outputs: - for output in self.outputs: - add_to_dict_v3(output, output_dict) - if self.hidden: - for hidden in self.hidden: - hidden_list.append(hidden.value) - - info = NodeInfoV3( - input=input_dict, - output=output_dict, - hidden=hidden_list, - name=self.node_id, - display_name=self.display_name, - description=self.description, - category=self.category, - output_node=self.is_output_node, - deprecated=self.is_deprecated, - experimental=self.is_experimental, - api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") + python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"), + price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None, + search_aliases=self.search_aliases if self.search_aliases else None, ) return info @@ -1416,6 +1501,7 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i "required": {}, "optional": {}, "dynamic_paths": {}, + "dynamic_paths_default_value": {}, } d = d.copy() # ignore hidden for parsing @@ -1425,8 +1511,12 @@ def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], i out_dict["hidden"] = hidden v3_data 
= {} dynamic_paths = out_dict.pop("dynamic_paths", None) - if dynamic_paths is not None: + if dynamic_paths is not None and len(dynamic_paths) > 0: v3_data["dynamic_paths"] = dynamic_paths + # this list is used for autogrow, in the case all inputs are optional and no values are passed + dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None) + if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0: + v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value return out_dict, hidden, v3_data def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None: @@ -1460,14 +1550,16 @@ def add_to_dict_v1(i: Input, d: dict): as_dict.pop("optional", None) d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) -def add_to_dict_v3(io: Input | Output, d: dict): - d[io.id] = (io.get_io_type(), io.as_dict()) +class DynamicPathsDefaultValue: + EMPTY_DICT = "empty_dict" def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) + default_value_dict = v3_data.get("dynamic_paths_default_value", {}) if paths is None: return values values = values.copy() + result = {} create_tuple = v3_data.get("create_dynamic_tuple", False) @@ -1481,6 +1573,11 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): if is_last: value = values.pop(key, None) + if value is None: + # see if a default value was provided for this key + default_option = default_value_dict.get(key, None) + if default_option == DynamicPathsDefaultValue.EMPTY_DICT: + value = {} if create_tuple: value = (value, key) current[p] = value @@ -1613,13 +1710,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): # set hidden type_clone.hidden = HiddenHolder.from_v3_data(v3_data) return type_clone - - @final - @classmethod - def GET_NODE_INFO_V3(cls) -> dict[str, Any]: - schema = cls.GET_SCHEMA() - info = schema.get_v3_info(cls) - return asdict(info) ############################################# # V1 Backwards Compatibility code #-------------------------------------------- @@ -1662,6 +1752,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._DEPRECATED + _DEV_ONLY = None + @final + @classproperty + def DEV_ONLY(cls): # noqa + if cls._DEV_ONLY is None: + cls.GET_SCHEMA() + return cls._DEV_ONLY + _API_NODE = None @final @classproperty @@ -1726,6 +1824,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._NOT_IDEMPOTENT + _ACCEPT_ALL_INPUTS = None + @final + @classproperty + def ACCEPT_ALL_INPUTS(cls): # noqa + if cls._ACCEPT_ALL_INPUTS is None: + cls.GET_SCHEMA() + return cls._ACCEPT_ALL_INPUTS + @final @classmethod def INPUT_TYPES(cls) -> dict[str, dict]: @@ -1756,6 +1862,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._EXPERIMENTAL = schema.is_experimental if cls._DEPRECATED is None: cls._DEPRECATED = schema.is_deprecated + if cls._DEV_ONLY is None: + cls._DEV_ONLY = schema.is_dev_only if cls._API_NODE is None: cls._API_NODE = schema.is_api_node if cls._OUTPUT_NODE is None: @@ -1764,6 +1872,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._INPUT_IS_LIST = schema.is_input_list if cls._NOT_IDEMPOTENT is None: cls._NOT_IDEMPOTENT = schema.not_idempotent + if cls._ACCEPT_ALL_INPUTS is None: + cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs if cls._RETURN_TYPES is None: output = [] @@ -1911,6 +2021,7 @@ __all__ = [ "ControlNet", "Vae", "Model", + "ModelPatch", 
"ClipVision", "ClipVisionOutput", "AudioEncoder", @@ -1943,6 +2054,7 @@ __all__ = [ "AnyType", "MultiType", "Tracks", + "Color", # Dynamic Types "MatchType", "DynamicCombo", @@ -1951,11 +2063,12 @@ __all__ = [ "HiddenHolder", "Hidden", "NodeInfoV1", - "NodeInfoV3", "Schema", "ComfyNode", "NodeOutput", "add_to_dict_v1", - "add_to_dict_v3", "V3Data", + "ImageCompare", + "PriceBadgeDepends", + "PriceBadge", ] diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md deleted file mode 100644 index f56d6c860..000000000 --- a/comfy_api_nodes/README.md +++ /dev/null @@ -1,65 +0,0 @@ -# ComfyUI API Nodes - -## Introduction - -Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview). - -## Development - -While developing, you should be testing against the Staging environment. To test against staging: - -**Install ComfyUI_frontend** - -Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication. - -> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication. - -```bash -python run main.py --comfy-api-base https://stagingapi.comfy.org -``` - -To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging. - -API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. - -### Redocly Instructions - -**Tip** -When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. - -Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. - -```bash -# Download the OpenAPI file from staging server. -curl -o openapi.yaml https://stagingapi.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` - - -# Merging to Master - -Before merging to comfyanonymous/ComfyUI master, follow these steps: - -1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. -1. Make sure the ComfyUI API is deployed to prod with your changes. -1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. - -```bash -# Download the OpenAPI file from prod server. -curl -o openapi.yaml https://api.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. 
-datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl.py similarity index 100% rename from comfy_api_nodes/apis/bfl_api.py rename to comfy_api_nodes/apis/bfl.py diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py new file mode 100644 index 000000000..9119cacc6 --- /dev/null +++ b/comfy_api_nodes/apis/bria.py @@ -0,0 +1,61 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputModerationSettings(TypedDict): + prompt_content_moderation: bool + visual_input_moderation: bool + visual_output_moderation: bool + + +class BriaEditImageRequest(BaseModel): + instruction: str | None = Field(...) + structured_instruction: str | None = Field( + ..., + description="Use this instead of instruction for precise, programmatic control.", + ) + images: list[str] = Field( + ..., + description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.", + ) + mask: str | None = Field( + None, + description="Mask image (black and white). Black areas will be preserved, white areas will be edited. " + "If omitted, the edit applies to the entire image. " + "The input image and the the input mask must be of the same size.", + ) + negative_prompt: str | None = Field(None) + guidance_scale: float = Field(...) + model_version: str = Field(...) + steps_num: int = Field(...) + seed: int = Field(...) + ip_signal: bool = Field( + False, + description="If true, returns a warning for potential IP content in the instruction.", + ) + prompt_content_moderation: bool = Field( + False, description="If true, returns 422 on instruction moderation failure." + ) + visual_input_content_moderation: bool = Field( + False, description="If true, returns 422 on images or mask moderation failure." + ) + visual_output_content_moderation: bool = Field( + False, description="If true, returns 422 on visual output moderation failure." + ) + + +class BriaStatusResponse(BaseModel): + request_id: str = Field(...) + status_url: str = Field(...) + warning: str | None = Field(None) + + +class BriaResult(BaseModel): + structured_prompt: str = Field(...) + image_url: str = Field(...) + + +class BriaResponse(BaseModel): + status: str = Field(...) + result: BriaResult | None = Field(None) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance.py similarity index 89% rename from comfy_api_nodes/apis/bytedance_api.py rename to comfy_api_nodes/apis/bytedance.py index b8c2f618b..23cbe2372 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -13,17 +13,6 @@ class Text2ImageTaskCreationRequest(BaseModel): watermark: bool | None = Field(False) -class Image2ImageTaskCreationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - response_format: str | None = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: str | None = Field("adaptive") - seed: int | None = Field(..., ge=0, le=2147483647) - guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(False) - - class Seedream4Options(BaseModel): max_images: int = Field(15) @@ -65,11 +54,13 @@ class TaskImageContent(BaseModel): class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) 
content: list[TaskTextContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(...) class Image2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + generate_audio: bool | None = Field(...) class TaskCreationResponse(BaseModel): @@ -141,4 +132,9 @@ VIDEO_TASKS_EXECUTION_TIME = { "720p": 65, "1080p": 100, }, + "seedance-1-5-pro-251215": { + "480p": 80, + "720p": 100, + "1080p": 150, + }, } diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini.py similarity index 100% rename from comfy_api_nodes/apis/gemini_api.py rename to comfy_api_nodes/apis/gemini.py diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py new file mode 100644 index 000000000..8e3c79ab9 --- /dev/null +++ b/comfy_api_nodes/apis/grok.py @@ -0,0 +1,67 @@ +from pydantic import BaseModel, Field + + +class ImageGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + aspect_ratio: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class InputUrlObject(BaseModel): + url: str = Field(...) + + +class ImageEditRequest(BaseModel): + model: str = Field(...) + image: InputUrlObject = Field(...) + prompt: str = Field(...) + resolution: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_for: str = Field("url") + + +class VideoGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + image: InputUrlObject | None = Field(...) + duration: int = Field(...) + aspect_ratio: str | None = Field(...) + resolution: str = Field(...) + seed: int = Field(...) + + +class VideoEditRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + video: InputUrlObject = Field(...) + seed: int = Field(...) + + +class ImageResponseObject(BaseModel): + url: str | None = Field(None) + b64_json: str | None = Field(None) + revised_prompt: str | None = Field(None) + + +class ImageGenerationResponse(BaseModel): + data: list[ImageResponseObject] = Field(...) + + +class VideoGenerationResponse(BaseModel): + request_id: str = Field(...) + + +class VideoResponseObject(BaseModel): + url: str = Field(...) + upsampled_prompt: str | None = Field(None) + duration: int = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str | None = Field(None) + video: VideoResponseObject | None = Field(None) + model: str | None = Field(None) diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py new file mode 100644 index 000000000..6421c9bd5 --- /dev/null +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -0,0 +1,66 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field, model_validator + + +class InputGenerateType(TypedDict): + generate_type: str + polygon_type: str + pbr: bool + + +class Hunyuan3DViewImage(BaseModel): + ViewType: str = Field(..., description="Valid values: back, left, right.") + ViewImageUrl: str = Field(...) + + +class To3DProTaskRequest(BaseModel): + Model: str = Field(...) + Prompt: str | None = Field(None) + ImageUrl: str | None = Field(None) + MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None) + EnablePBR: bool | None = Field(...) + FaceCount: int | None = Field(...) + GenerateType: str | None = Field(...) + PolygonType: str | None = Field(...) 
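+    # (Field(...) marks these as required even though the annotations allow None,
+    # i.e. required-but-nullable in pydantic; the CamelCase names presumably mirror
+    # the upstream Hunyuan 3D API schema)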
+ + +class RequestError(BaseModel): + Code: str = Field("") + Message: str = Field("") + + +class To3DProTaskCreateResponse(BaseModel): + JobId: str | None = Field(None) + Error: RequestError | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class ResultFile3D(BaseModel): + Type: str = Field(...) + Url: str = Field(...) + PreviewImageUrl: str = Field("") + + +class To3DProTaskResultResponse(BaseModel): + ErrorCode: str = Field("") + ErrorMessage: str = Field("") + ResultFile3Ds: list[ResultFile3D] = Field([]) + Status: str = Field(...) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class To3DProTaskQueryRequest(BaseModel): + JobId: str = Field(...) diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py new file mode 100644 index 000000000..737e18e3b --- /dev/null +++ b/comfy_api_nodes/apis/ideogram.py @@ -0,0 +1,292 @@ +from enum import Enum +from typing import Optional, List, Dict, Any, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel, StrictBytes + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[int] = Field( + None, + description='Optional. 
A number between 0 and 2147483647.', + ge=0, + le=2147483647, + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' + ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class Datum1(BaseModel): + is_image_safe: Optional[bool] = None + prompt: Optional[str] = None + resolution: Optional[str] = None + seed: Optional[int] = None + style_type: Optional[str] = None + url: Optional[str] = None + + +class IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + +class IdeogramV3ReframeRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: 
Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' + + +class RenderingSpeed(str, Enum): + DEFAULT = 'DEFAULT' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should be used in generating the request or not.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' + ) + rendering_speed: RenderingSpeed + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' 
+ ) + + +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling.py similarity index 100% rename from comfy_api_nodes/apis/kling_api.py rename to comfy_api_nodes/apis/kling.py diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma.py similarity index 100% rename from comfy_api_nodes/apis/luma_api.py rename to comfy_api_nodes/apis/luma.py diff --git a/comfy_api_nodes/apis/magnific.py b/comfy_api_nodes/apis/magnific.py new file mode 100644 index 000000000..b9f148def --- /dev/null +++ b/comfy_api_nodes/apis/magnific.py @@ -0,0 +1,122 @@ +from typing import TypedDict + +from pydantic import AliasChoices, BaseModel, Field, model_validator + + +class InputPortraitMode(TypedDict): + portrait_mode: str + portrait_style: str + portrait_beautifier: str + + +class InputAdvancedSettings(TypedDict): + advanced_settings: str + whites: int + blacks: int + brightness: int + contrast: int + saturation: int + engine: str + transfer_light_a: str + transfer_light_b: str + fixed_generation: bool + + +class InputSkinEnhancerMode(TypedDict): + mode: str + skin_detail: int + optimized_for: str + + +class ImageUpscalerCreativeRequest(BaseModel): + image: str = Field(...) + scale_factor: str = Field(...) + optimized_for: str = Field(...) + prompt: str | None = Field(None) + creativity: int = Field(...) + hdr: int = Field(...) + resemblance: int = Field(...) + fractality: int = Field(...) + engine: str = Field(...) + + +class ImageUpscalerPrecisionV2Request(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) 
+ ultra_detail: int = Field(...) + flavor: str = Field(...) + scale_factor: int = Field(...) + + +class ImageRelightAdvancedSettingsRequest(BaseModel): + whites: int = Field(...) + blacks: int = Field(...) + brightness: int = Field(...) + contrast: int = Field(...) + saturation: int = Field(...) + engine: str = Field(...) + transfer_light_a: str = Field(...) + transfer_light_b: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageRelightRequest(BaseModel): + image: str = Field(...) + prompt: str | None = Field(None) + transfer_light_from_reference_image: str | None = Field(None) + light_transfer_strength: int = Field(...) + interpolate_from_original: bool = Field(...) + change_background: bool = Field(...) + style: str = Field(...) + preserve_details: bool = Field(...) + advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...) + + +class ImageStyleTransferRequest(BaseModel): + image: str = Field(...) + reference_image: str = Field(...) + prompt: str | None = Field(None) + style_strength: int = Field(...) + structure_strength: int = Field(...) + is_portrait: bool = Field(...) + portrait_style: str | None = Field(...) + portrait_beautifier: str | None = Field(...) + flavor: str = Field(...) + engine: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageSkinEnhancerCreativeRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + + +class ImageSkinEnhancerFaithfulRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + skin_detail: int = Field(...) + + +class ImageSkinEnhancerFlexibleRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + optimized_for: str = Field(...) + + +class TaskResponse(BaseModel): + """Unified response model that handles both wrapped and unwrapped API responses.""" + + task_id: str = Field(...) + status: str = Field(validation_alias=AliasChoices("status", "task_status")) + generated: list[str] | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "data" in values and isinstance(values["data"], dict): + return values["data"] + return values diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py new file mode 100644 index 000000000..be46d0d58 --- /dev/null +++ b/comfy_api_nodes/apis/meshy.py @@ -0,0 +1,160 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + +from comfy_api.latest import Input + + +class InputShouldRemesh(TypedDict): + should_remesh: str + topology: str + target_polycount: int + + +class InputShouldTexture(TypedDict): + should_texture: str + enable_pbr: bool + texture_prompt: str + texture_image: Input.Image | None + + +class MeshyTaskResponse(BaseModel): + result: str = Field(...) + + +class MeshyTextToModelRequest(BaseModel): + mode: str = Field("preview") + prompt: str = Field(..., max_length=600) + art_style: str = Field(..., description="'realistic' or 'sculpture'") + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + pose_mode: str = Field(...) + seed: int = Field(...) 
+ moderation: bool = Field(False) + + +class MeshyRefineTask(BaseModel): + mode: str = Field("refine") + preview_task_id: str = Field(...) + enable_pbr: bool | None = Field(...) + texture_prompt: str | None = Field(...) + texture_image_url: str | None = Field(...) + ai_model: str = Field(...) + moderation: bool = Field(False) + + +class MeshyImageToModelRequest(BaseModel): + image_url: str = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyMultiImageToModelRequest(BaseModel): + image_urls: list[str] = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRiggingRequest(BaseModel): + input_task_id: str = Field(...) + height_meters: float = Field(...) + texture_image_url: str | None = Field(...) + + +class MeshyAnimationRequest(BaseModel): + rig_task_id: str = Field(...) + action_id: int = Field(...) + + +class MeshyTextureRequest(BaseModel): + input_task_id: str = Field(...) + ai_model: str = Field(...) + enable_original_uv: bool = Field(...) + enable_pbr: bool = Field(...) + text_style_prompt: str | None = Field(...) + image_style_url: str | None = Field(...) + + +class MeshyModelsUrls(BaseModel): + glb: str = Field("") + + +class MeshyRiggedModelsUrls(BaseModel): + rigged_character_glb_url: str = Field("") + + +class MeshyAnimatedModelsUrls(BaseModel): + animation_glb_url: str = Field("") + + +class MeshyResultTextureUrls(BaseModel): + base_color: str = Field(...) + metallic: str | None = Field(None) + normal: str | None = Field(None) + roughness: str | None = Field(None) + + +class MeshyTaskError(BaseModel): + message: str | None = Field(None) + + +class MeshyModelResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + model_urls: MeshyModelsUrls = Field(MeshyModelsUrls()) + thumbnail_url: str = Field(...) + video_url: str | None = Field(None) + status: str = Field(...) + progress: int = Field(0) + texture_urls: list[MeshyResultTextureUrls] | None = Field([]) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyRiggedResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) + progress: int = Field(0) + result: MeshyRiggedModelsUrls = Field(MeshyRiggedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyAnimationResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) 
+ progress: int = Field(0) + result: MeshyAnimatedModelsUrls = Field(MeshyAnimatedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax.py similarity index 100% rename from comfy_api_nodes/apis/minimax_api.py rename to comfy_api_nodes/apis/minimax.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py new file mode 100644 index 000000000..7ec7a4ade --- /dev/null +++ b/comfy_api_nodes/apis/moonvalley.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Optional, Dict, Any + +from pydantic import BaseModel, Field, StrictBytes + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: 
Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py new file mode 100644 index 000000000..b85ef252b --- /dev/null +++ b/comfy_api_nodes/apis/openai.py @@ -0,0 +1,170 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = Field(None) + text_tokens: int | None = Field(None) + + +class Usage(BaseModel): + input_tokens: int | None = Field(None) + input_tokens_details: InputTokensDetails | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = Field(None) + usage: Usage | None = Field(None) + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) 
+    moderation: str | None = Field(None)
+    n: int | None = Field(None, description="The number of images to generate")
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the output image (e.g., 1024x1024, 1536x1024, auto)")
+
+
+class OpenAIImageGenerationRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str | None = Field(None)
+    moderation: str | None = Field(None)
+    n: int | None = Field(
+        None,
+        description="The number of images to generate.",
+    )
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
+
+
+class ModelResponseProperties(BaseModel):
+    instructions: str | None = Field(None)
+    max_output_tokens: int | None = Field(None)
+    model: str | None = Field(None)
+    temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
+    top_p: float | None = Field(
+        1,
+        description="Controls diversity of the response via nucleus sampling",
+        ge=0.0,
+        le=1.0,
+    )
+    truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
+
+
+class ResponseProperties(BaseModel):
+    instructions: str | None = Field(None)
+    max_output_tokens: int | None = Field(None)
+    model: str | None = Field(None)
+    previous_response_id: str | None = Field(None)
+    truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
+
+
+class ResponseError(BaseModel):
+    code: str = Field(...)
+    message: str = Field(...)
+
+
+class OutputTokensDetails(BaseModel):
+    reasoning_tokens: int = Field(..., description="The number of reasoning tokens.")
+
+
+class CachedTokensDetails(BaseModel):
+    cached_tokens: int = Field(
+        ...,
+        description="The number of tokens that were retrieved from the cache.",
+    )
+
+
+class ResponseUsage(BaseModel):
+    input_tokens: int = Field(..., description="The number of input tokens.")
+    input_tokens_details: CachedTokensDetails = Field(...)
+    output_tokens: int = Field(..., description="The number of output tokens.")
+    output_tokens_details: OutputTokensDetails = Field(...)
+ total_tokens: int = Field(..., description="The total number of tokens used.") + + +class InputTextContent(BaseModel): + text: str = Field(..., description="The text input to the model.") + type: str = Field("input_text") + + +class OutputContent(BaseModel): + type: str = Field(..., description="The type of output content") + text: str | None = Field(None, description="The text content") + data: str | None = Field(None, description="Base64-encoded audio data") + transcript: str | None = Field(None, description="Transcript of the audio") + + +class OutputMessage(BaseModel): + type: str = Field(..., description="The type of output item") + content: list[OutputContent] | None = Field(None, description="The content of the message") + role: str | None = Field(None, description="The role of the message") + + +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: float | None = Field( + None, + description="Unix timestamp (in seconds) of when this Response was created.", + ) + error: ResponseError | None = Field(None) + id: str | None = Field(None, description="Unique identifier for this Response.") + object: str | None = Field(None, description="The object type of this resource - always set to `response`.") + output: list[OutputMessage] | None = Field(None) + parallel_tool_calls: bool | None = Field(True) + status: str | None = Field( + None, + description="One of `completed`, `failed`, `in_progress`, or `incomplete`.", + ) + usage: ResponseUsage | None = Field(None) + + +class InputImageContent(BaseModel): + detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.") + file_id: str | None = Field(None) + image_url: str | None = Field(None) + type: str = Field(..., description="The type of the input item. Always `input_image`.") + + +class InputFileContent(BaseModel): + file_data: str | None = Field(None) + file_id: str | None = Field(None) + filename: str | None = Field(None, description="The name of the file to be sent to the model.") + type: str = Field(..., description="The type of the input item. Always `input_file`.") + + +class InputMessage(BaseModel): + content: list[InputTextContent | InputImageContent | InputFileContent] = Field( + ..., + description="A list of one or many input items to the model, containing different content types.", + ) + role: str | None = Field(None) + type: str | None = Field(None) + + +class OpenAICreateResponse(ModelResponseProperties, ResponseProperties): + include: str | None = Field(None) + input: list[InputMessage] = Field(...) + parallel_tool_calls: bool | None = Field( + True, description="Whether to allow the model to run tool calls in parallel." 
+ ) + store: bool | None = Field( + True, + description="Whether to store the generated model response for later retrieval via API.", + ) + stream: bool | None = Field(False) + usage: ResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py deleted file mode 100644 index ae5bb2673..000000000 --- a/comfy_api_nodes/apis/openai_api.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel, Field - - -class Datum2(BaseModel): - b64_json: str | None = Field(None, description="Base64 encoded image data") - revised_prompt: str | None = Field(None, description="Revised prompt") - url: str | None = Field(None, description="URL of the image") - - -class InputTokensDetails(BaseModel): - image_tokens: int | None = None - text_tokens: int | None = None - - -class Usage(BaseModel): - input_tokens: int | None = None - input_tokens_details: InputTokensDetails | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class OpenAIImageGenerationResponse(BaseModel): - data: list[Datum2] | None = None - usage: Usage | None = None - - -class OpenAIImageEditRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str = Field(...) - moderation: str | None = Field(None) - n: int | None = Field(None, description="The number of images to generate") - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - size: str | None = Field(None, description="Size of the output image") - - -class OpenAIImageGenerationRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str | None = Field(None) - moderation: str | None = Field(None) - n: int | None = Field( - None, - description="The number of images to generate.", - ) - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) 
- quality: str | None = Field(None, description="The quality of the generated image") - size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse.py similarity index 100% rename from comfy_api_nodes/apis/pixverse_api.py rename to comfy_api_nodes/apis/pixverse.py diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft.py similarity index 74% rename from comfy_api_nodes/apis/recraft_api.py rename to comfy_api_nodes/apis/recraft.py index c36d95f24..0bd7d23b3 100644 --- a/comfy_api_nodes/apis/recraft_api.py +++ b/comfy_api_nodes/apis/recraft.py @@ -1,11 +1,8 @@ from __future__ import annotations - - from enum import Enum -from typing import Optional -from pydantic import BaseModel, Field, conint, confloat +from pydantic import BaseModel, Field class RecraftColor: @@ -229,24 +226,24 @@ class RecraftColorObject(BaseModel): class RecraftControlsObject(BaseModel): - colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors') - background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color') - no_text: Optional[bool] = Field(None, description='Do not embed text layouts') - artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].') + colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors') + background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color') + no_text: bool | None = Field(None, description='Do not embed text layouts') + artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. 
The value should be in range [0..5].') class RecraftImageGenerationRequest(BaseModel): prompt: str = Field(..., description='The text prompt describing the image to generate') - size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")') - n: conint(ge=1, le=6) = Field(..., description='The number of images to generate') - negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image') - model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') - style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') - substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input') - controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process') - style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID') - strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') - random_seed: Optional[int] = Field(None, description="Seed for video generation") + size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') + n: int = Field(..., description='The number of images to generate') + negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image') + model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') + style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') + substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input') + controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process') + style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID') + strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') + random_seed: int | None = Field(None, description="Seed for video generation") # text_layout @@ -258,5 +255,13 @@ class RecraftReturnedObject(BaseModel): class RecraftImageGenerationResponse(BaseModel): created: int = Field(..., description='Unix timestamp when the generation was created') credits: int = Field(..., description='Number of credits used for the generation') - data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information') - image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image') + data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information') + image: RecraftReturnedObject | None = Field(None, description='Single generated image') + + +class RecraftCreateStyleRequest(BaseModel): + style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon") + + +class RecraftCreateStyleResponse(BaseModel): + id: str = Field(..., description="UUID of the created style") 
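The recraft.py hunk above drops the conint/confloat constraints in favour of plain built-in types and adds the style-upload models. As a minimal sketch of how these request models could be exercised on their own: the import path follows the renamed module, while the exclude_none serialization step and the concrete field values are illustrative assumptions, since the node code that actually submits these payloads is not part of this hunk.

from comfy_api_nodes.apis.recraft import (
    RecraftCreateStyleRequest,
    RecraftImageGenerationRequest,
)

# Build a generation request; "model" falls back to its recraftv3 default,
# and optional fields left as None are dropped from the serialized payload.
req = RecraftImageGenerationRequest(
    prompt="a watercolor fox in a snowy forest",
    n=2,
    negative_prompt="text, watermark",
    random_seed=42,
)
payload = req.model_dump(exclude_none=True)

# The new style-upload request only carries the base style name;
# "digital_illustration" is one of the values listed in its description.
style_req = RecraftCreateStyleRequest(style="digital_illustration")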
diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin.py similarity index 100% rename from comfy_api_nodes/apis/rodin_api.py rename to comfy_api_nodes/apis/rodin.py diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py new file mode 100644 index 000000000..df6f2b845 --- /dev/null +++ b/comfy_api_nodes/apis/runway.py @@ -0,0 +1,127 @@ +from enum import Enum +from typing import Optional, List, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( + ..., + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", + ) + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( + None, + description='Float value between 0 and 1 representing the progress of the task. 
Only available if status is RUNNING.', + ge=0.0, + le=1.0, + ) + status: RunwayTaskStatusEnum + + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' + ) + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + + +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) + + +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability.py similarity index 100% rename from comfy_api_nodes/apis/stability_api.py rename to comfy_api_nodes/apis/stability.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz.py similarity index 97% rename from comfy_api_nodes/apis/topaz_api.py rename to comfy_api_nodes/apis/topaz.py index 4d9e62e72..a9e6235a7 100644 --- a/comfy_api_nodes/apis/topaz_api.py +++ b/comfy_api_nodes/apis/topaz.py @@ -41,7 +41,7 @@ class Resolution(BaseModel): height: int = Field(...) -class CreateCreateVideoRequestSource(BaseModel): +class CreateVideoRequestSource(BaseModel): container: str = Field(...) size: int = Field(..., description="Size of the video file in bytes") duration: int = Field(..., description="Duration of the video file in seconds") @@ -89,7 +89,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): - source: CreateCreateVideoRequestSource = Field(...) + source: CreateVideoRequestSource = Field(...) filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo.py similarity index 100% rename from comfy_api_nodes/apis/tripo_api.py rename to comfy_api_nodes/apis/tripo.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo.py similarity index 100% rename from comfy_api_nodes/apis/veo_api.py rename to comfy_api_nodes/apis/veo.py diff --git a/comfy_api_nodes/apis/vidu.py b/comfy_api_nodes/apis/vidu.py new file mode 100644 index 000000000..469adcdbc --- /dev/null +++ b/comfy_api_nodes/apis/vidu.py @@ -0,0 +1,65 @@ +from pydantic import BaseModel, Field + + +class SubjectReference(BaseModel): + id: str = Field(...) + images: list[str] = Field(...) + + +class FrameSetting(BaseModel): + prompt: str = Field(...) + key_image: str = Field(...) + duration: int = Field(...) + + +class TaskMultiFrameCreationRequest(BaseModel): + model: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + start_image: str = Field(...) + image_settings: list[FrameSetting] = Field(...) + + +class TaskExtendCreationRequest(BaseModel): + model: str = Field(...) 
+ prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + video_url: str = Field(..., description="URL of the video to extend") + + +class TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + aspect_ratio: str | None = Field(None) + resolution: str | None = Field(None) + movement_amplitude: str | None = Field(None) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + subjects: list[SubjectReference] | None = Field(None) + bgm: bool | None = Field(None) + audio: bool | None = Field(None) + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: str = Field(...) + created_at: str = Field(...) + code: int | None = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: str = Field(...) + err_code: str | None = Field(None) + progress: float | None = Field(None) + credits: int | None = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py new file mode 100644 index 000000000..07a7bfa5d --- /dev/null +++ b/comfy_api_nodes/apis/wavespeed.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class SeedVR2ImageRequest(BaseModel): + image: str = Field(...) + target_resolution: str = Field(...) + output_format: str = Field("png") + enable_sync_mode: bool = Field(False) + + +class FlashVSRRequest(BaseModel): + target_resolution: str = Field(...) + video: str = Field(...) + duration: float = Field(...) + + +class TaskCreatedDataResponse(BaseModel): + id: str = Field(...) + + +class TaskCreatedResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreatedDataResponse | None = Field(None) + + +class TaskResultDataResponse(BaseModel): + status: str = Field(...) + outputs: list[str] = Field([]) + + +class TaskResultResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) 
+ data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py deleted file mode 100644 index 4df7590b6..000000000 --- a/comfy_api_nodes/canary.py +++ /dev/null @@ -1,10 +0,0 @@ -import av - -ver = av.__version__.split(".") -if int(ver[0]) < 14: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -if int(ver[0]) == 14 and int(ver[1]) < 2: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -NODE_CLASS_MAPPINGS = {} diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py deleted file mode 100644 index 6fab8f4bb..000000000 --- a/comfy_api_nodes/mapper_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from enum import Enum - -from pydantic.fields import FieldInfo -from pydantic import BaseModel -from pydantic_core import PydanticUndefined - -from comfy.comfy_types.node_typing import IO, InputTypeOptions - -NodeInput = tuple[IO, InputTypeOptions] - - -def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: - config = {} - if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: - config["default"] = field_info.default - if hasattr(field_info, "description") and field_info.description is not None: - config["tooltip"] = field_info.description - return config - - -def _get_number_constraints_config(field_info: FieldInfo) -> dict: - config = {} - if hasattr(field_info, "metadata"): - metadata = field_info.metadata - for constraint in metadata: - if hasattr(constraint, "ge"): - config["min"] = constraint.ge - if hasattr(constraint, "le"): - config["max"] = constraint.le - if hasattr(constraint, "multiple_of"): - config["step"] = constraint.multiple_of - return config - - -def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.IMAGE, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.STRING, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.FLOAT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.INT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_combo_input( - field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs -) -> NodeInput: - combo_config = {} - if enum_type is not None: - combo_config["options"] = [option.value for option in enum_type] - combo_config = { - **combo_config, - **_create_base_config(field_info), - **kwargs, - } - return IO.COMBO, combo_config - - -def model_field_to_node_input( - input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs -) -> NodeInput: - """ - Maps a field from a Pydantic model to a Comfy node input. - - Args: - input_type: The type of the input. - base_model: The Pydantic model to map the field from. - field_name: The name of the field to map. - **kwargs: Additional key/values to include in the input options. - - Note: - For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. 
- - Example: - >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) - >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) - >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) - """ - field_info: FieldInfo = base_model.model_fields[field_name] - result: NodeInput - - if input_type == IO.IMAGE: - result = _model_field_to_image_input(field_info, **kwargs) - elif input_type == IO.STRING: - result = _model_field_to_string_input(field_info, **kwargs) - elif input_type == IO.FLOAT: - result = _model_field_to_float_input(field_info, **kwargs) - elif input_type == IO.INT: - result = _model_field_to_int_input(field_info, **kwargs) - elif input_type == IO.COMBO: - result = _model_field_to_combo_input(field_info, **kwargs) - else: - message = f"Invalid input type: {input_type}" - raise ValueError(message) - - return result diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index ce077d6b3..61c3b4503 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -3,7 +3,7 @@ from pydantic import BaseModel from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bfl_api import ( +from comfy_api_nodes.apis.bfl import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, @@ -97,6 +97,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.06}""", + ), ) @classmethod @@ -352,6 +355,9 @@ class FluxProExpandNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -458,6 +464,9 @@ class FluxProFillNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod @@ -511,6 +520,21 @@ class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" DISPLAY_NAME = "Flux.2 [pro] Image" API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.03 + 0.015 * ($outMP - 1); + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.015, + "max_usd": $outputCost + 0.12, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ @classmethod def define_schema(cls) -> IO.Schema: @@ -563,6 +587,10 @@ class Flux2ProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]), + expr=cls.PRICE_BADGE_EXPR, + ), ) @classmethod @@ -623,6 +651,22 @@ class Flux2MaxImageNode(Flux2ProImageNode): NODE_ID = "Flux2MaxImageNode" DISPLAY_NAME = "Flux.2 [max] Image" API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.07 + 0.03 * ($outMP - 1); + + inputs.images.connected + ? 
{ + "type":"range_usd", + "min_usd": $outputCost + 0.03, + "max_usd": $outputCost + 0.24, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ class BFLExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py new file mode 100644 index 000000000..d3a52bc1b --- /dev/null +++ b/comfy_api_nodes/nodes_bria.py @@ -0,0 +1,198 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bria import ( + BriaEditImageRequest, + BriaResponse, + BriaStatusResponse, + InputModerationSettings, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + convert_mask_to_image, + download_url_to_image_tensor, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, +) + + +class BriaImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaImageEditNode", + display_name="Bria FIBO Image Edit", + category="api node/image/Bria", + description="Edit images using Bria latest model", + inputs=[ + IO.Combo.Input("model", options=["FIBO"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Instruction to edit image", + ), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.String.Input( + "structured_prompt", + multiline=True, + default="", + tooltip="A string containing the structured edit prompt in JSON format. " + "Use this instead of usual prompt for precise, programmatic control.", + ), + IO.Int.Input( + "seed", + default=1, + min=1, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Float.Input( + "guidance_scale", + default=3, + min=3, + max=5, + step=0.01, + display_mode=IO.NumberDisplay.number, + tooltip="Higher value makes the image follow the prompt more closely.", + ), + IO.Int.Input( + "steps", + default=50, + min=20, + max=50, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.DynamicCombo.Input( + "moderation", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "prompt_content_moderation", default=False + ), + IO.Boolean.Input( + "visual_input_moderation", default=False + ), + IO.Boolean.Input( + "visual_output_moderation", default=True + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Moderation settings", + ), + IO.Mask.Input( + "mask", + tooltip="If omitted, the edit applies to the entire image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(display_name="structured_prompt"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + negative_prompt: str, + structured_prompt: str, + seed: int, + guidance_scale: float, + steps: int, + moderation: InputModerationSettings, + mask: Input.Image | None = None, + ) -> IO.NodeOutput: + if not prompt and not structured_prompt: + raise ValueError( + "One of prompt or structured_prompt is required to be non-empty." 
+ ) + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + mask_url = None + if mask is not None: + mask_url = ( + await upload_images_to_comfyapi( + cls, + convert_mask_to_image(mask), + max_images=1, + mime_type="image/png", + wait_label="Uploading mask", + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), + data=BriaEditImageRequest( + instruction=prompt if prompt else None, + structured_instruction=structured_prompt if structured_prompt else None, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + wait_label="Uploading image", + ), + mask=mask_url, + negative_prompt=negative_prompt if negative_prompt else None, + guidance_scale=guidance_scale, + seed=seed, + model_version=model, + steps_num=steps, + prompt_content_moderation=moderation.get( + "prompt_content_moderation", False + ), + visual_input_content_moderation=moderation.get( + "visual_input_moderation", False + ), + visual_output_content_moderation=moderation.get( + "visual_output_moderation", False + ), + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaResponse, + ) + return IO.NodeOutput( + await download_url_to_image_tensor(response.result.image_url), + response.result.structured_prompt, + ) + + +class BriaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BriaImageEditNode, + ] + + +async def comfy_entrypoint() -> BriaExtension: + return BriaExtension() diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index d4a2cfae6..0cb5e3be8 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -5,11 +5,10 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bytedance_api import ( +from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, - Image2ImageTaskCreationRequest, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedream4Options, @@ -126,6 +125,9 @@ class ByteDanceImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -171,99 +173,6 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) -class ByteDanceImageEditNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ByteDanceImageEditNode", - display_name="ByteDance Image Edit", - category="api node/image/ByteDance", - description="Edit images using ByteDance models via api based on prompt", - inputs=[ - IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), - IO.Image.Input( - "image", - tooltip="The base image to edit", - ), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Instruction to edit image", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to use for generation", - optional=True, - ), - IO.Float.Input( - "guidance_scale", - default=5.5, - min=1.0, - max=10.0, - step=0.01, - 
display_mode=IO.NumberDisplay.number, - tooltip="Higher value makes the image follow the prompt more closely", - optional=True, - ), - IO.Boolean.Input( - "watermark", - default=False, - tooltip='Whether to add an "AI generated" watermark to the image', - optional=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - is_deprecated=True, - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str, - seed: int, - guidance_scale: float, - watermark: bool, - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1)) - source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] - payload = Image2ImageTaskCreationRequest( - model=model, - prompt=prompt, - image=source_url, - seed=seed, - guidance_scale=guidance_scale, - watermark=watermark, - ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - data=payload, - response_model=ImageTaskCreationResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - - class ByteDanceSeedreamNode(IO.ComfyNode): @classmethod @@ -367,6 +276,19 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $price := $contains(widgets.model, "seedream-4-5-251128") ? 0.04 : 0.03; + { + "type":"usd", + "usd": $price, + "format": { "suffix":" x images/Run", "approximate": true } + } + ) + """, + ), ) @classmethod @@ -461,7 +383,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-t2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -512,6 +439,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -522,6 +455,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -535,7 +469,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) @@ -550,7 +487,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ) return await process_video_task( cls, - payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + 
generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, + ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -567,7 +508,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-i2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -622,6 +568,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -632,6 +584,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -646,7 +599,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) @@ -668,6 +624,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -685,7 +642,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( @@ -744,6 +701,12 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -754,6 +717,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -769,7 +733,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): @@ -802,6 +769,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): 
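# Illustrative sketch (assumption, not taken from this patch): the gating the Seedance video
# nodes above apply to generate_audio and the 4-second floor for seedance-1-5-pro. The helper
# name below is hypothetical; the constants mirror the node code shown above.
SEEDANCE_15_PRO = "seedance-1-5-pro-251215"

def seedance_audio_payload(model: str, duration: int, generate_audio: bool) -> bool | None:
    """Return the generate_audio payload value, or None for models that ignore it."""
    if model == SEEDANCE_15_PRO and duration < 4:
        raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.")
    return generate_audio if model == SEEDANCE_15_PRO else None

# Worked pricing note from the PRICE_BADGE_VIDEO table defined later in this file:
# seedance-1-5-pro at 720p lists $0.26 per 10 s, so a 5 s clip is ~$0.13, and the badge
# doubles that to ~$0.26 when generate_audio is enabled for that model.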
TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -877,6 +845,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? {"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, + ), ) @classmethod @@ -946,12 +949,64 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: ) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]), + expr=""" + ( + $priceByModel := { + "seedance-1-5-pro": { + "480p":[0.12,0.12], + "720p":[0.26,0.26], + "1080p":[0.58,0.59] + }, + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56], + "1080p":[1.18,1.22] + }, + "seedance-1-0-pro-fast": { + "480p":[0.09,0.1], + "720p":[0.21,0.23], + "1080p":[0.47,0.49] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41], + "1080p":[0.85,0.88] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" : + $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "1080") ? "1080p" : + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1; + $minCost := $min10s * $scale * $audioMultiplier; + $maxCost := $max10s * $scale * $audioMultiplier; + ($minCost = $maxCost) + ? 
{"type":"usd","usd": $minCost, "format": { "approximate": true }} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }} + ) + """, +) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ ByteDanceImageNode, - ByteDanceImageEditNode, ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index e8ed7e797..3b31caa7b 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -14,7 +14,7 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input, Types -from comfy_api_nodes.apis.gemini_api import ( +from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, @@ -130,7 +130,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera Returns: List of response parts matching the requested type. """ - if response.candidates is None: + if not response.candidates: if response.promptFeedback and response.promptFeedback.blockReason: feedback = response.promptFeedback raise ValueError( @@ -141,14 +141,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." ) parts = [] - for part in response.candidates[0].content.parts: - if part_type == "text" and part.text: - parts.append(part) - elif part.inlineData and part.inlineData.mimeType == part_type: - parts.append(part) - elif part.fileData and part.fileData.mimeType == part_type: - parts.append(part) - # Skip parts that don't match the requested type + blocked_reasons = [] + for candidate in response.candidates: + if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT": + blocked_reasons.append(candidate.finishReason) + continue + if candidate.content is None or candidate.content.parts is None: + continue + for part in candidate.content.parts: + if part_type == "text" and part.text: + parts.append(part) + elif part.inlineData and part.inlineData.mimeType == part_type: + parts.append(part) + elif part.fileData and part.fileData.mimeType == part_type: + parts.append(part) + + if not parts and blocked_reasons: + raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}") + return parts @@ -309,6 +319,30 @@ class GeminiNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "gemini-2.5-flash") ? { + "type": "list_usd", + "usd": [0.0003, 0.0025], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"} + } + : $contains($m, "gemini-2.5-pro") ? { + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gemini-3-pro-preview") ? 
{ + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type":"text", "text":"Token-based"} + ) + """, + ), ) @classmethod @@ -570,6 +604,9 @@ class GeminiImage(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""", + ), ) @classmethod @@ -700,6 +737,19 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $r := widgets.resolution; + ($contains($r,"1k") or $contains($r,"2k")) + ? {"type":"usd","usd":0.134,"format":{"suffix":"/Image","approximate":true}} + : $contains($r,"4k") + ? {"type":"usd","usd":0.24,"format":{"suffix":"/Image","approximate":true}} + : {"type":"text","text":"Token-based"} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py new file mode 100644 index 000000000..da15e97ea --- /dev/null +++ b/comfy_api_nodes/nodes_grok.py @@ -0,0 +1,417 @@ +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.grok import ( + ImageEditRequest, + ImageGenerationRequest, + ImageGenerationResponse, + InputUrlObject, + VideoEditRequest, + VideoGenerationRequest, + VideoGenerationResponse, + VideoStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_video_to_comfyapi, + validate_string, + validate_video_duration, +) + + +class GrokImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageNode", + display_name="Grok Image", + category="api node/image/Grok", + description="Generate images using Grok based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + ), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), + data=ImageGenerationRequest( + 
model=model, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageEditNode", + display_name="Grok Image Edit", + category="api node/image/Grok", + description="Modify an existing image based on a text prompt", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-image-beta"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input("resolution", options=["1K"]), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of edited images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]), + expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + resolution: str, + number_of_images: int, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), + data=ImageEditRequest( + model=model, + image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"), + prompt=prompt, + resolution=resolution.lower(), + n=number_of_images, + seed=seed, + ), + response_model=ImageGenerationResponse, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoNode", + display_name="Grok Video", + category="api node/video/Grok", + description="Generate video from a prompt or an image", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=1, + max=15, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + 
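# Illustrative sketch (assumption): what a helper like tensor_to_base64_string used in the
# Grok data-URL payloads above is presumed to do -- encode an image tensor as a base64 PNG
# string. This standalone version is hypothetical, not the project's implementation.
import base64
import io

import numpy as np
import torch
from PIL import Image


def image_tensor_to_base64_png(image: torch.Tensor) -> str:
    # Accepts a [1, H, W, C] batch or a single [H, W, C] float tensor in the 0..1 range.
    if image.dim() == 4:
        image = image[0]
    array = (image.clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")

# Usage mirroring the request construction above (hypothetical helper name):
# InputUrlObject(url=f"data:image/png;base64,{image_tensor_to_base64_png(image)}")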
IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Image.Input("image", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]), + expr=""" + ( + $base := 0.181 * widgets.duration; + {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + image: Input.Image | None = None, + ) -> IO.NodeOutput: + image_url = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") + validate_string(prompt, strip_whitespace=True, min_length=1) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model, + image=image_url, + prompt=prompt, + resolution=resolution, + duration=duration, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoEditNode", + display_name="Grok Video Edit", + category="api node/video/Grok", + description="Edit an existing video based on a text prompt.", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=1, max_duration=8.7) + video_stream = video.get_stream_source() + video_size = get_fs_object_size(video_stream) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + 
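# Illustrative sketch (assumption): the submit-then-poll shape the async nodes above follow
# with sync_op and poll_op -- create the task, then re-query its status until it finishes.
# fetch_status and the status strings below are hypothetical stand-ins, not the real poll_op.
import asyncio
from typing import Awaitable, Callable


async def poll_until_complete(
    fetch_status: Callable[[], Awaitable[str]],
    poll_interval: float = 5.0,
    max_attempts: int = 120,
) -> None:
    # Re-query the task status until completion, failure, or timeout.
    for _ in range(max_attempts):
        status = await fetch_status()
        if status in ("complete", "completed", "success"):
            return
        if status in ("failed", "error"):
            raise RuntimeError(f"Task failed with status: {status}")
        await asyncio.sleep(poll_interval)
    raise TimeoutError("Task did not complete within the polling window.")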
ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), + data=VideoEditRequest( + model=model, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + prompt=prompt, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GrokImageNode, + GrokImageEditNode, + GrokVideoNode, + GrokVideoEditNode, + ] + + +async def comfy_entrypoint() -> GrokExtension: + return GrokExtension() diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py new file mode 100644 index 000000000..b3a736643 --- /dev/null +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -0,0 +1,297 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hunyuan3d import ( + Hunyuan3DViewImage, + InputGenerateType, + ResultFile3D, + To3DProTaskCreateResponse, + To3DProTaskQueryRequest, + To3DProTaskRequest, + To3DProTaskResultResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + downscale_image_tensor_by_max_side, + poll_op, + sync_op, + upload_image_to_comfyapi, + validate_image_dimensions, + validate_string, +) +from folder_paths import get_output_directory + + +def get_glb_obj_from_response(response_objs: list[ResultFile3D]) -> ResultFile3D: + for i in response_objs: + if i.Type.lower() == "glb": + return i + raise ValueError("No GLB file found in response. Please report this to the developers.") + + +class TencentTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentTextToModelNode", + display_name="Hunyuan3D: Text to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $pbr := $lookup(widgets, "generate_type.pbr") ? 
10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + face_count: int, + generate_type: InputGenerateType, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + Prompt=prompt, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentImageToModelNode", + display_name="Hunyuan3D: Image(s) to Model (Pro)", + category="api node/3d/Tencent", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["generate_type", "generate_type.pbr", "face_count"], + inputs=["image_left", "image_right", "image_back"], + ), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $multiview := ( + inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected + ) ? 10 : 0; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 
10 : 0; + {"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + face_count: int, + generate_type: InputGenerateType, + seed: int, + image_left: Input.Image | None = None, + image_right: Input.Image | None = None, + image_back: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + validate_image_dimensions(image, min_width=128, min_height=128) + multiview_images = [] + for k, v in { + "left": image_left, + "right": image_right, + "back": image_back, + }.items(): + if v is None: + continue + validate_image_dimensions(v, min_width=128, min_height=128) + multiview_images.append( + Hunyuan3DViewImage( + ViewType=k, + ViewImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(v, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + ) + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + ImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(image, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + MultiViewImages=multiview_images if multiview_images else None, + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + model_file = f"hunyuan_model_{response.JobId}.glb" + await download_url_to_bytesio( + get_glb_obj_from_response(result.ResultFile3Ds).Url, + os.path.join(get_output_directory(), model_file), + ) + return IO.NodeOutput(model_file) + + +class TencentHunyuan3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TencentTextToModelNode, + TencentImageToModelNode, + ] + + +async def comfy_entrypoint() -> TencentHunyuan3DExtension: + return TencentHunyuan3DExtension() diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 48f94e612..feaf7a858 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.ideogram import ( IdeogramGenerateRequest, IdeogramGenerateResponse, ImageRequest, @@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -298,6 +297,17 @@ class IdeogramV1(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := 
(widgets.turbo = true) ? 0.0286 : 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -351,7 +361,6 @@ class IdeogramV2(IO.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -436,6 +445,17 @@ class IdeogramV2(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 0.0715 : 0.1144; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -506,7 +526,6 @@ class IdeogramV3(IO.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -591,6 +610,23 @@ class IdeogramV3(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]), + expr=""" + ( + $n := widgets.num_images; + $speed := widgets.rendering_speed; + $hasChar := inputs.character_image.connected; + $base := + $contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) : + $contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) : + $contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) : + 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9c707a339..739fe1855 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -49,7 +49,7 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.kling_api import ( +from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, MotionControlRequest, OmniImageParamImage, @@ -249,7 +249,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), - max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -567,7 +566,7 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: audio_url = await upload_audio_to_comfyapi( - cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: @@ -764,6 +763,33 @@ class KlingTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? 
{"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -818,6 +844,16 @@ class OmniProTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -886,6 +922,16 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -981,6 +1027,16 @@ class OmniProImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1056,6 +1112,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1142,6 +1208,16 @@ class OmniProEditVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod @@ -1228,6 +1304,9 @@ class OmniProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.028}""", + ), ) @classmethod @@ -1313,6 +1392,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14}""", + ), ) @classmethod @@ -1375,6 +1457,33 @@ class KlingImage2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + $contains($model,"v2-5-turbo") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : ($contains($model,"v2-1-master") or $contains($model,"v2-master")) + ? 
($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1448,6 +1557,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.49}""", + ), ) @classmethod @@ -1518,6 +1630,33 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1583,6 +1722,9 @@ class KlingVideoExtendNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.28}""", + ), ) @classmethod @@ -1664,6 +1806,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + ($contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1728,6 +1893,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]), + expr=""" + ( + ($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom")) + ? 
{"type":"usd","usd":0.49} + : {"type":"usd","usd":0.28} + ) + """, + ), ) @classmethod @@ -1782,6 +1957,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1842,6 +2020,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1892,6 +2073,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.7}""", + ), ) @classmethod @@ -1991,6 +2175,19 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]), + expr=""" + ( + $m := widgets.model_name; + $base := + $contains($m,"kling-v1-5") + ? (inputs.image.connected ? 0.028 : 0.014) + : ($contains($m,"kling-v1") ? 0.0035 : 0.014); + {"type":"usd","usd": $base * widgets.n} + ) + """, + ), ) @classmethod @@ -2074,6 +2271,10 @@ class TextToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", + ), ) @classmethod @@ -2138,6 +2339,10 @@ class ImageToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 
2 : 1)}""", + ), ) @classmethod @@ -2218,6 +2423,15 @@ class MotionControl(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $prices := {"std": 0.07, "pro": 0.112}; + {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 7e61560dc..c6424af92 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel): image_uri: str | None = Field(None) +PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $prices := { + "ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24}, + "ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16} + }; + $modelPrices := $lookup($prices, $lowercase(widgets.model)); + $pps := $lookup($modelPrices, widgets.resolution); + {"type":"usd","usd": $pps * widgets.duration} + ) + """, +) + + class TextToVideoNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -69,6 +85,7 @@ class TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod @@ -145,6 +162,7 @@ class ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 894f2b08c..9ed6cd299 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.luma_api import ( +from comfy_api_nodes.apis.luma import ( LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? 
{"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, + ) @classmethod @@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return LumaKeyframes(frame0=frame0, frame1=frame1) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]), + expr=""" + ( + $p := { + "ray-flash-2": { + "5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2}, + "9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36} + }, + "ray-2": { + "5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57}, + "9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03} + } + }; + + $m := widgets.model; + $d := widgets.duration; + $r := widgets.resolution; + + $modelKey := + $contains($m,"ray-flash-2") ? "ray-flash-2" : + $contains($m,"ray-2") ? "ray-2" : + $contains($m,"ray-1-6") ? "ray-1-6" : + "other"; + + $durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : ""; + $resKey := + $contains($r,"4k") ? "4k" : + $contains($r,"1080p") ? "1080p" : + $contains($r,"720p") ? "720p" : + $contains($r,"540p") ? "540p" : ""; + + $modelPrices := $lookup($p, $modelKey); + $durPrices := $lookup($modelPrices, $durKey); + $v := $lookup($durPrices, $resKey); + + $price := + ($modelKey = "ray-1-6") ? 0.5 : + ($modelKey = "other") ? 0.79 : + ($exists($v) ? $v : 0.79); + + {"type":"usd","usd": $price} + ) + """, +) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py new file mode 100644 index 000000000..013e71cc8 --- /dev/null +++ b/comfy_api_nodes/nodes_magnific.py @@ -0,0 +1,889 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.magnific import ( + ImageRelightAdvancedSettingsRequest, + ImageRelightRequest, + ImageSkinEnhancerCreativeRequest, + ImageSkinEnhancerFaithfulRequest, + ImageSkinEnhancerFlexibleRequest, + ImageStyleTransferRequest, + ImageUpscalerCreativeRequest, + ImageUpscalerPrecisionV2Request, + InputAdvancedSettings, + InputPortraitMode, + InputSkinEnhancerMode, + TaskResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + downscale_image_tensor, + get_image_dimensions, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, +) + + +class MagnificImageUpscalerCreativeNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerCreativeNode", + display_name="Magnific Image Upscale (Creative)", + category="api node/image/Magnific", + description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. 
" + "Maximum output: 25.3 megapixels.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "optimized_for", + options=[ + "standard", + "soft_portraits", + "hard_portraits", + "art_n_illustration", + "videogame_assets", + "nature_n_landscapes", + "films_n_photography", + "3d_renders", + "science_fiction_n_horror", + ], + ), + IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider), + IO.Int.Input( + "hdr", + min=-10, + max=10, + default=0, + tooltip="The level of definition and detail.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "resemblance", + min=-10, + max=10, + default=0, + tooltip="The level of resemblance to the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "fractality", + min=-10, + max=10, + default=0, + tooltip="The strength of the prompt and intricacy per square pixel.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"], + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum pixel limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + scale_factor: str, + optimized_for: str, + creativity: int, + hdr: int, + resemblance: int, + fractality: int, + engine: str, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_pixels = 25_300_000 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.rstrip("x")) + output_pixels = height * width * requested_scale * requested_scale + + if output_pixels > max_output_pixels: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + input_pixels = width * height + scale = 2 + max_input_pixels = max_output_pixels // 4 + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= max_output_pixels: + scale = candidate + max_input_pixels = None + break + downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = max_output_pixels // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + scale_factor = f"{scale}x" + else: + raise ValueError( + f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {max_output_pixels:,} pixels. " + f"Use a smaller input image or lower scale factor." 
+ ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerCreativeRequest( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=scale_factor, + optimized_for=optimized_for, + creativity=creativity, + hdr=hdr, + resemblance=resemblance, + fractality=fractality, + engine=engine, + prompt=prompt if prompt else None, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerPreciseV2Node", + display_name="Magnific Image Upscale (Precise V2)", + category="api node/image/Magnific", + description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " + "Maximum output: 10060×10060 pixels.", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "flavor", + options=["sublime", "photo", "photo_denoiser"], + tooltip="Processing style: " + "sublime for general use, photo for photographs, photo_denoiser for noisy photos.", + ), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=7, + tooltip="Image sharpness intensity. Higher values increase edge definition and clarity.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=7, + tooltip="Intelligent grain/texture enhancement to prevent the image from " + "looking too smooth or artificial.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "ultra_detail", + min=0, + max=100, + default=30, + tooltip="Controls fine detail, textures, and micro-details added during upscaling.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum resolution.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $max := widgets.scale_factor = "2x" ? 1.326 : 1.657; + {"type": "range_usd", "min_usd": 0.11, "max_usd": $max} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + scale_factor: str, + flavor: str, + sharpen: int, + smart_grain: int, + ultra_detail: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_dimension = 10060 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.strip("x")) + output_width = width * requested_scale + output_height = height * requested_scale + + if output_width > max_output_dimension or output_height > max_output_dimension: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. 
+ # Server upscales in 2x steps, so aggressive downscaling degrades quality. + max_dim = max(width, height) + scale = 2 + max_input_dim = max_output_dimension // 2 + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + output_dim = max_dim * candidate + if output_dim <= max_output_dimension: + scale = candidate + max_input_pixels = None + break + downscale_ratio = output_dim / max_output_dimension + if downscale_ratio <= 2.0: + scale = candidate + max_input_dim = max_output_dimension // candidate + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + requested_scale = scale + else: + raise ValueError( + f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed " + f"resolution of {max_output_dimension}x{max_output_dimension} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerPrecisionV2Request( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=requested_scale, + flavor=flavor, + sharpen=sharpen, + smart_grain=smart_grain, + ultra_detail=ultra_detail, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageStyleTransferNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageStyleTransferNode", + display_name="Magnific Image Style Transfer", + category="api node/image/Magnific", + description="Transfer the style from a reference image to your input image.", + inputs=[ + IO.Image.Input("image", tooltip="The image to apply style transfer to."), + IO.Image.Input("reference_image", tooltip="The reference image to extract style from."), + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "style_strength", + min=0, + max=100, + default=100, + tooltip="Percentage of style strength.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "structure_strength", + min=0, + max=100, + default=50, + tooltip="Maintains the structure of the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "flavor", + options=["faithful", "gen_z", "psychedelia", "detaily", "clear", "donotstyle", "donotstyle_sharp"], + tooltip="Style transfer flavor.", + ), + IO.Combo.Input( + "engine", + options=[ + "balanced", + "definio", + "illusio", + "3d_cartoon", + "colorful_anime", + "caricature", + "real", + "super_real", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.DynamicCombo.Input( + "portrait_mode", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Combo.Input( + "portrait_style", + options=["standard", "pop", "super_pop"], + tooltip="Visual style applied to portrait images.", + ), + IO.Combo.Input( + "portrait_beautifier", + options=["none", "beautify_face", 
"beautify_face_max"], + tooltip="Facial beautification intensity on portraits.", + ), + ], + ), + ], + tooltip="Enable portrait mode for facial enhancements.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="When disabled, expect each generation to introduce a degree of randomness, " + "leading to more diverse outcomes.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + reference_image: Input.Image, + prompt: str, + style_strength: int, + structure_strength: int, + flavor: str, + engine: str, + portrait_mode: InputPortraitMode, + fixed_generation: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + is_portrait = portrait_mode["portrait_mode"] == "enabled" + portrait_style = portrait_mode.get("portrait_style", "standard") + portrait_beautifier = portrait_mode.get("portrait_beautifier", "none") + + uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"), + response_model=TaskResponse, + data=ImageStyleTransferRequest( + image=uploaded_urls[0], + reference_image=uploaded_urls[1], + prompt=prompt if prompt else None, + style_strength=style_strength, + structure_strength=structure_strength, + is_portrait=is_portrait, + portrait_style=portrait_style if is_portrait else None, + portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None, + flavor=flavor, + engine=engine, + fixed_generation=fixed_generation, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageRelightNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageRelightNode", + display_name="Magnific Image Relight", + category="api node/image/Magnific", + description="Relight an image with lighting adjustments and optional reference-based light transfer.", + inputs=[ + IO.Image.Input("image", tooltip="The image to relight."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Descriptive guidance for lighting. 
Supports emphasis notation (1-1.4).", + ), + IO.Int.Input( + "light_transfer_strength", + min=0, + max=100, + default=100, + tooltip="Intensity of light transfer application.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "style", + options=[ + "standard", + "darker_but_realistic", + "clean", + "smooth", + "brighter", + "contrasted_n_hdr", + "just_composition", + ], + tooltip="Stylistic output preference.", + ), + IO.Boolean.Input( + "interpolate_from_original", + default=False, + tooltip="Restricts generation freedom to match original more closely.", + ), + IO.Boolean.Input( + "change_background", + default=True, + tooltip="Modifies background based on prompt/reference.", + ), + IO.Boolean.Input( + "preserve_details", + default=True, + tooltip="Maintains texture and fine details from original.", + ), + IO.DynamicCombo.Input( + "advanced_settings", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "whites", + min=0, + max=100, + default=50, + tooltip="Adjusts the brightest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "blacks", + min=0, + max=100, + default=50, + tooltip="Adjusts the darkest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "brightness", + min=0, + max=100, + default=50, + tooltip="Overall brightness adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "contrast", + min=0, + max=100, + default=50, + tooltip="Contrast adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "saturation", + min=0, + max=100, + default=50, + tooltip="Color saturation adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=[ + "automatic", + "balanced", + "cool", + "real", + "illusio", + "fairy", + "colorful_anime", + "hard_transform", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.Combo.Input( + "transfer_light_a", + options=["automatic", "low", "medium", "normal", "high", "high_on_faces"], + tooltip="The intensity of light transfer.", + ), + IO.Combo.Input( + "transfer_light_b", + options=[ + "automatic", + "composition", + "straight", + "smooth_in", + "smooth_out", + "smooth_both", + "reverse_both", + "soft_in", + "soft_out", + "soft_mid", + # "strong_mid", # Commented out because requests fail when this is set. + "style_shift", + "strong_shift", + ], + tooltip="Also modifies light transfer intensity. 
" + "Can be combined with the previous control for varied effects.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="Ensures consistent output with the same settings.", + ), + ], + ), + ], + tooltip="Fine-tuning options for advanced lighting control.", + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip="Optional reference image to transfer lighting from.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + light_transfer_strength: int, + style: str, + interpolate_from_original: bool, + change_background: bool, + preserve_details: bool, + advanced_settings: InputAdvancedSettings, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if reference_image is not None and get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + if reference_image is not None: + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + reference_url = None + if reference_image is not None: + reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0] + + adv_settings = None + if advanced_settings["advanced_settings"] == "enabled": + adv_settings = ImageRelightAdvancedSettingsRequest( + whites=advanced_settings["whites"], + blacks=advanced_settings["blacks"], + brightness=advanced_settings["brightness"], + contrast=advanced_settings["contrast"], + saturation=advanced_settings["saturation"], + engine=advanced_settings["engine"], + transfer_light_a=advanced_settings["transfer_light_a"], + transfer_light_b=advanced_settings["transfer_light_b"], + fixed_generation=advanced_settings["fixed_generation"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"), + response_model=TaskResponse, + data=ImageRelightRequest( + image=image_url, + prompt=prompt if prompt else None, + transfer_light_from_reference_image=reference_url, + light_transfer_strength=light_transfer_strength, + interpolate_from_original=interpolate_from_original, + change_background=change_background, + style=style, + preserve_details=preserve_details, + advanced_settings=adv_settings, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageSkinEnhancerNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageSkinEnhancerNode", + display_name="Magnific Image Skin Enhancer", + category="api node/image/Magnific", + description="Skin enhancement for portraits with multiple processing modes.", + inputs=[ + IO.Image.Input("image", 
tooltip="The portrait image to enhance."), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=0, + tooltip="Sharpening intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=2, + tooltip="Smart grain intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.DynamicCombo.Input( + "mode", + options=[ + IO.DynamicCombo.Option("creative", []), + IO.DynamicCombo.Option( + "faithful", + [ + IO.Int.Input( + "skin_detail", + min=0, + max=100, + default=80, + tooltip="Skin detail enhancement level.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option( + "flexible", + [ + IO.Combo.Input( + "optimized_for", + options=[ + "enhance_skin", + "improve_lighting", + "enhance_everything", + "transform_to_real", + "no_make_up", + ], + tooltip="Enhancement optimization target.", + ), + ], + ), + ], + tooltip="Processing mode: creative for artistic enhancement, " + "faithful for preserving original appearance, " + "flexible for targeted optimization.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45}; + {"type":"usd","usd": $lookup($rates, widgets.mode)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + sharpen: int, + smart_grain: int, + mode: InputSkinEnhancerMode, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0] + selected_mode = mode["mode"] + + if selected_mode == "creative": + endpoint = "creative" + data = ImageSkinEnhancerCreativeRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + ) + elif selected_mode == "faithful": + endpoint = "faithful" + data = ImageSkinEnhancerFaithfulRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + skin_detail=mode["skin_detail"], + ) + else: # flexible + endpoint = "flexible" + data = ImageSkinEnhancerFlexibleRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + optimized_for=mode["optimized_for"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"), + response_model=TaskResponse, + data=data, + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + # MagnificImageUpscalerCreativeNode, + # MagnificImageUpscalerPreciseV2Node, + MagnificImageStyleTransferNode, + MagnificImageRelightNode, + MagnificImageSkinEnhancerNode, + ] + + +async def comfy_entrypoint() -> MagnificExtension: + return MagnificExtension() diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py new 
file mode 100644 index 000000000..740607983 --- /dev/null +++ b/comfy_api_nodes/nodes_meshy.py @@ -0,0 +1,790 @@ +import os + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.meshy import ( + InputShouldRemesh, + InputShouldTexture, + MeshyAnimationRequest, + MeshyAnimationResult, + MeshyImageToModelRequest, + MeshyModelResult, + MeshyMultiImageToModelRequest, + MeshyRefineTask, + MeshyRiggedResult, + MeshyRiggingRequest, + MeshyTaskResponse, + MeshyTextToModelRequest, + MeshyTextureRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_bytesio, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_string, +) +from folder_paths import get_output_directory + + +class MeshyTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextToModelNode", + display_name="Meshy: Text to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("style", options=["realistic", "sculpture"]), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.8}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + style: str, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=600) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextToModelRequest( + prompt=prompt, + art_style=style, + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + should_remesh=should_remesh["should_remesh"] == "true", + symmetry_mode=symmetry_mode, + pose_mode=pose_mode.lower(), + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return 
IO.NodeOutput(model_file, response.result) + + +class MeshyRefineNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRefineNode", + display_name="Meshy: Refine Draft Model", + category="api node/3d/Meshy", + description="Refine a previously created draft model.", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) in addition to the base color. " + "Note: this should be set to false when using Sculpture style, " + "as Sculpture style generates its own set of PBR maps.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' may be used at the same time.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_pbr: bool, + texture_prompt: str, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if texture_prompt and texture_image is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + texture_image_url = None + if texture_prompt: + validate_string(texture_prompt, field_name="texture_prompt", max_length=600) + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRefineTask( + preview_task_id=meshy_task_id, + enable_pbr=enable_pbr, + texture_prompt=texture_prompt if texture_prompt else None, + texture_image_url=texture_image_url, + ai_model=model, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyImageToModelNode", + display_name="Meshy: Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Image.Input("image"), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + 
IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. " + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 1.2, "false": 0.8}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyImageToModelRequest( + image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0], + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{response.result}"), + response_model=MeshyModelResult, + 
status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyMultiImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyMultiImageToModelNode", + display_name="Meshy: Multi-Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=2, max=4), + ), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. 
" + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 0.6, "false": 0.2}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + images: IO.Autogrow.Type, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyMultiImageToModelRequest( + image_urls=await upload_images_to_comfyapi( + cls, list(images.values()), wait_label="Uploading base images" + ), + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyRigModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRigModelNode", + display_name="Meshy: Rig Model", + category="api node/3d/Meshy", + description="Provides a rigged character in standard formats. 
" + "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " + "or humanoid assets with unclear limb and body structure.", + inputs=[ + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Float.Input( + "height_meters", + min=0.1, + max=15.0, + default=1.7, + tooltip="The approximate height of the character model in meters. " + "This aids in scaling and rigging accuracy.", + ), + IO.Image.Input( + "texture_image", + tooltip="The model's UV-unwrapped base color texture image.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), + ) + + @classmethod + async def execute( + cls, + meshy_task_id: str, + height_meters: float, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + texture_image_url = None + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRiggingRequest( + input_task_id=meshy_task_id, + height_meters=height_meters, + texture_image_url=texture_image_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{response.result}"), + response_model=MeshyRiggedResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio( + result.result.rigged_character_glb_url, os.path.join(get_output_directory(), model_file) + ) + return IO.NodeOutput(model_file, response.result) + + +class MeshyAnimateModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyAnimateModelNode", + display_name="Meshy: Animate Model", + category="api node/3d/Meshy", + description="Apply a specific animation action to a previously rigged character.", + inputs=[ + IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), + IO.Int.Input( + "action_id", + default=0, + min=0, + max=696, + tooltip="Visit https://docs.meshy.ai/en/api/animation-library for a list of available values.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.12}""", + ), + ) + + @classmethod + async def execute( + cls, + rig_task_id: str, + action_id: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyAnimationRequest( + rig_task_id=rig_task_id, + action_id=action_id, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{response.result}"), + response_model=MeshyAnimationResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.result.animation_glb_url, os.path.join(get_output_directory(), 
model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyTextureNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextureNode", + display_name="Meshy: Texture Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_original_uv", + default=True, + tooltip="Use the original UV of the model instead of generating new UVs. " + "When enabled, Meshy preserves existing textures from the uploaded model. " + "If the model has no original UV, the quality of the output might not be as good.", + ), + IO.Boolean.Input("pbr", default=False), + IO.String.Input( + "text_style_prompt", + default="", + multiline=True, + tooltip="Describe your desired texture style of the object using text. Maximum 600 characters." + "Maximum 600 characters. Cannot be used at the same time as 'image_style'.", + ), + IO.Image.Input( + "image_style", + optional=True, + tooltip="A 2d image to guide the texturing process. " + "Can not be used at the same time with 'text_style_prompt'.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_original_uv: bool, + pbr: bool, + text_style_prompt: str, + image_style: Input.Image | None = None, + ) -> IO.NodeOutput: + if text_style_prompt and image_style is not None: + raise ValueError("text_style_prompt and image_style cannot be used at the same time") + if not text_style_prompt and image_style is None: + raise ValueError("Either text_style_prompt or image_style is required") + image_style_url = None + if image_style is not None: + image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextureRequest( + input_task_id=meshy_task_id, + ai_model=model, + enable_original_uv=enable_original_uv, + enable_pbr=pbr, + text_style_prompt=text_style_prompt if text_style_prompt else None, + image_style_url=image_style_url, + ), + ) + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{response.result}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + model_file = f"meshy_model_{response.result}.glb" + await download_url_to_bytesio(result.model_urls.glb, os.path.join(get_output_directory(), model_file)) + return IO.NodeOutput(model_file, response.result) + + +class MeshyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MeshyTextToModelNode, + MeshyRefineNode, + MeshyImageToModelNode, + MeshyMultiImageToModelNode, + MeshyRigModelNode, + MeshyAnimateModelNode, + MeshyTextureNode, + ] + + +async def comfy_entrypoint() -> MeshyExtension: + return MeshyExtension() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 05cbb700f..b5d0b461f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ 
b/comfy_api_nodes/nodes_minimax.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.minimax_api import ( +from comfy_api_nodes.apis.minimax import ( MinimaxFileRetrieveResponse, MiniMaxModel, MinimaxTaskResultResponse, @@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $prices := { + "768p": {"6": 0.28, "10": 0.56}, + "1080p": {"6": 0.49} + }; + $resPrices := $lookup($prices, $lowercase(widgets.resolution)); + $price := $lookup($resPrices, $string(widgets.duration)); + {"type":"usd","usd": $price ? $price : 0.43} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 2771e4790..08315fa2b 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -3,7 +3,7 @@ import logging from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.moonvalley import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoRequest, @@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod @@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 2.25}""", + ), ) @classmethod @@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a6205a34f..f05aaab7b 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -10,24 +10,18 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( - CreateModelResponseProperties, - Detail, - InputContent, +from comfy_api_nodes.apis.openai import ( InputFileContent, InputImageContent, InputMessage, - InputMessageContentList, InputTextContent, - Item, + ModelResponseProperties, OpenAICreateResponse, - OpenAIResponse, - OutputContent, -) -from comfy_api_nodes.apis.openai_api import ( OpenAIImageEditRequest, OpenAIImageGenerationRequest, OpenAIImageGenerationResponse, + OpenAIResponse, + OutputContent, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -160,6 +154,23 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]), + expr=""" + ( + $size := widgets.size; + $nRaw := widgets.n; + $n := ($nRaw != 
null and $nRaw != 0) ? $nRaw : 1; + + $base := + $contains($size, "256x256") ? 0.016 : + $contains($size, "512x512") ? 0.018 : + 0.02; + + {"type":"usd","usd": $round($base * $n, 3)} + ) + """, + ), ) @classmethod @@ -249,7 +260,7 @@ class OpenAIDalle3(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -287,6 +298,25 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]), + expr=""" + ( + $size := widgets.size; + $q := widgets.quality; + $hd := $contains($q, "hd"); + + $price := + $contains($size, "1024x1024") + ? ($hd ? 0.08 : 0.04) + : (($contains($size, "1792x1024") or $contains($size, "1024x1792")) + ? ($hd ? 0.12 : 0.08) + : 0.04); + + {"type":"usd","usd": $price} + ) + """, + ), ) @classmethod @@ -334,9 +364,9 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1", + display_name="OpenAI GPT Image 1.5", category="api node/image/OpenAI", - description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", + description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( "prompt", @@ -348,7 +378,7 @@ class OpenAIGPTImage1(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -399,6 +429,7 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "model", options=["gpt-image-1", "gpt-image-1.5"], + default="gpt-image-1.5", optional=True, ), ], @@ -411,6 +442,28 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + expr=""" + ( + $ranges := { + "low": [0.011, 0.02], + "medium": [0.046, 0.07], + "high": [0.167, 0.3] + }; + $range := $lookup($ranges, widgets.quality); + $n := widgets.n; + ($n = 1) + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + : { + "type":"range_usd", + "min_usd": $range[0], + "max_usd": $range[1], + "format": { "suffix": " x " & $string($n) & "/Run" } + } + ) + """, + ), ) @classmethod @@ -442,8 +495,8 @@ class OpenAIGPTImage1(IO.ComfyNode): files = [] batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i: i + 1] - scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -465,7 +518,7 @@ class OpenAIGPTImage1(IO.ComfyNode): rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) @@ -566,32 +619,95 @@ class OpenAIChatNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "o4-mini") ? 
{ + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1-pro") ? { + "type": "list_usd", + "usd": [0.15, 0.6], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1") ? { + "type": "list_usd", + "usd": [0.015, 0.06], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3") ? { + "type": "list_usd", + "usd": [0.01, 0.04], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4o") ? { + "type": "list_usd", + "usd": [0.0025, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-nano") ? { + "type": "list_usd", + "usd": [0.0001, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-mini") ? { + "type": "list_usd", + "usd": [0.0004, 0.0016], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1") ? { + "type": "list_usd", + "usd": [0.002, 0.008], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-nano") ? { + "type": "list_usd", + "usd": [0.00005, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-mini") ? { + "type": "list_usd", + "usd": [0.00025, 0.002], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5") ? 
{ + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type": "text", "text": "Token-based"} + ) + """, + ), ) @classmethod - def get_message_content_from_response( - cls, response: OpenAIResponse - ) -> list[OutputContent]: + def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: - if output.root.type == "message": - return output.root.content + if output.type == "message": + return output.content raise TypeError("No output message found in response") @classmethod - def get_text_from_message_content( - cls, message_content: list[OutputContent] - ) -> str: + def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str: """Extract text content from message content.""" for content_item in message_content: - if content_item.root.type == "output_text": - return str(content_item.root.text) + if content_item.type == "output_text": + return str(content_item.text) return "No text output found in response" @classmethod - def tensor_to_input_image_content( - cls, image: torch.Tensor, detail_level: Detail = "auto" - ) -> InputImageContent: + def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( detail=detail_level, @@ -605,9 +721,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, image: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - ) -> InputMessageContentList: + ) -> list[InputTextContent | InputImageContent | InputFileContent]: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ + content_list: list[InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -619,13 +735,9 @@ class OpenAIChatNode(IO.ComfyNode): type="input_image", ) ) - if files is not None: content_list.extend(files) - - return InputMessageContentList( - root=content_list, - ) + return content_list @classmethod async def execute( @@ -635,7 +747,7 @@ class OpenAIChatNode(IO.ComfyNode): model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - advanced_options: CreateModelResponseProperties | None = None, + advanced_options: ModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -646,36 +758,28 @@ class OpenAIChatNode(IO.ComfyNode): response_model=OpenAIResponse, data=OpenAICreateResponse( input=[ - Item( - root=InputMessage( - content=cls.create_input_message_contents( - prompt, images, files - ), - role="user", - ) + InputMessage( + content=cls.create_input_message_contents(prompt, images, files), + role="user", ), ], store=True, stream=False, model=model, previous_response_id=None, - **( - advanced_options.model_dump(exclude_none=True) - if advanced_options - else {} - ), + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), ), ) response_id = create_response.id # Get result output result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - 
completed_statuses=["incomplete", "completed"] - ) + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"], + ) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) @@ -796,7 +900,7 @@ class OpenAIChatConfig(IO.ComfyNode): remove depending on model choice. """ return IO.NodeOutput( - CreateModelResponseProperties( + ModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 6e1686af0..e17a24ae7 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.pixverse_api import ( +from comfy_api_nodes.apis.pixverse import ( PixverseTextVideoRequest, PixverseImageVideoRequest, PixverseTransitionVideoRequest, @@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]), + expr=""" + ( + $prices := { + "5": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 0.6, "fast": 1.2}, + "540p": {"normal": 0.45, "fast": 0.9}, + "360p": {"normal": 0.45, "fast": 0.9} + }, + "8": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 1.2, "fast": 1.2}, + "540p": {"normal": 0.9, "fast": 1.2}, + "360p": {"normal": 0.9, "fast": 1.2} + } + }; + $durPrices := $lookup($prices, $string(widgets.duration_seconds)); + $qualityPrices := $lookup($durPrices, widgets.quality); + $price := $lookup($qualityPrices, widgets.motion_mode); + {"type":"usd","usd": $price ? 
$price : 0.9} + ) + """, +) + + class PixVerseExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e3440b946..3a1f32263 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -8,10 +8,12 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.recraft_api import ( +from comfy_api_nodes.apis.recraft import ( RecraftColor, RecraftColorChain, RecraftControls, + RecraftCreateStyleRequest, + RecraftCreateStyleResponse, RecraftImageGenerationRequest, RecraftImageGenerationResponse, RecraftImageSize, @@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.NodeOutput(RecraftStyle(style_id=style_id)) +class RecraftCreateStyleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreateStyleNode", + display_name="Recraft Create Style", + category="api node/image/Recraft", + description="Create a custom style from reference images. " + "Upload 1-5 images to use as style references. " + "Total size of all images is limited to 5 MB.", + inputs=[ + IO.Combo.Input( + "style", + options=["realistic_image", "digital_illustration"], + tooltip="The base style of the generated images.", + ), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image", + min=1, + max=5, + ), + ), + ], + outputs=[ + IO.String.Output(display_name="style_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + style: str, + images: IO.Autogrow.Type, + ) -> IO.NodeOutput: + files = [] + total_size = 0 + max_total_size = 5 * 1024 * 1024 # 5 MB limit + for i, img in enumerate(list(images.values())): + file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read() + total_size += len(file_bytes) + if total_size > max_total_size: + raise Exception("Total size of all images exceeds 5 MB limit.") + files.append((f"file{i + 1}", file_bytes)) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"), + response_model=RecraftCreateStyleResponse, + files=files, + data=RecraftCreateStyleRequest(style=style), + content_type="multipart/form-data", + max_retries=1, + ) + + return IO.NodeOutput(response.id) + + class RecraftTextToImageNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -378,6 +449,10 @@ class RecraftTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -391,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode): negative_prompt: str = None, recraft_controls: RecraftControls = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False, max_length=1000) + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -490,6 +565,10 @@ class RecraftImageToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + 
price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -591,6 +670,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -692,6 +775,10 @@ class RecraftTextToVectorNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""", + ), ) @classmethod @@ -759,6 +846,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 0.01}""", + ), ) @classmethod @@ -817,6 +908,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), ) @classmethod @@ -883,6 +977,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -929,6 +1026,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.004}""", + ), ) @classmethod @@ -972,6 +1072,9 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @@ -992,6 +1095,7 @@ class RecraftExtension(ComfyExtension): RecraftStyleV3DigitalIllustrationNode, RecraftStyleV3LogoRasterNode, RecraftStyleInfiniteStyleLibrary, + RecraftCreateStyleNode, RecraftColorRGBNode, RecraftControlsNode, ] diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e60e7a6d6..3ffdc8b90 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -14,7 +14,7 @@ from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin_api import ( +from comfy_api_nodes.apis.rodin import ( Rodin3DGenerateRequest, Rodin3DGenerateResponse, Rodin3DCheckStatusRequest, @@ -241,6 +241,9 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -294,6 +297,9 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -347,6 +353,9 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -406,6 +415,9 @@ class Rodin3D_Sketch(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 3c55039c9..573170ba2 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -16,7 +16,7 @@ from enum import Enum from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from 
comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.runway import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, @@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 92b225d40..afc18bb25 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]), + expr=""" + ( + $m := widgets.model; + $size := widgets.size; + $dur := widgets.duration; + $isPro := $contains($m, "sora-2-pro"); + $isSora2 := $contains($m, "sora-2"); + $isProSize := ($size = "1024x1792" or $size = "1792x1024"); + $perSec := + $isPro ? ($isProSize ? 0.5 : 0.3) : + $isSora2 ? 0.1 : + ($isProSize ? 0.5 : 0.1); + {"type":"usd","usd": $round($perSec * $dur, 2)} + ) + """, + ), ) @classmethod @@ -131,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode): response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) return IO.NodeOutput( diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index bb7ceed78..5665109cf 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -3,7 +3,7 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.stability_api import ( +from comfy_api_nodes.apis.stability import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, StabilityAsyncResponse, @@ -127,6 +127,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.08}""", + ), ) @classmethod @@ -264,6 +267,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model,"large") + ? 
{"type":"usd","usd":0.065} + : {"type":"usd","usd":0.035} + ) + """, + ), ) @classmethod @@ -382,6 +395,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -486,6 +502,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -566,6 +585,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -648,6 +670,9 @@ class StabilityTextToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -732,6 +757,9 @@ class StabilityAudioToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -828,6 +856,9 @@ class StabilityAudioInpaint(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index b04575ad8..8fccde25a 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -2,11 +2,27 @@ import builtins from io import BytesIO import aiohttp -import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.apis.topaz import ( + CreateVideoRequest, + CreateVideoRequestSource, + CreateVideoResponse, + ImageAsyncTaskResponse, + ImageDownloadResponse, + ImageEnhanceRequest, + ImageStatusResponse, + OutputInformationVideo, + Resolution, + VideoAcceptResponse, + VideoCompleteUploadRequest, + VideoCompleteUploadRequestPart, + VideoCompleteUploadResponse, + VideoEnhancementFilter, + VideoFrameInterpolationFilter, + VideoStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -138,7 +154,7 @@ class TopazImageEnhance(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str = "", subject_detection: str = "All", face_enhancement: bool = True, @@ -153,12 +169,14 @@ class TopazImageEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") - download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + download_url = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 + ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), - response_model=topaz_api.ImageAsyncTaskResponse, - data=topaz_api.ImageEnhanceRequest( + response_model=ImageAsyncTaskResponse, + data=ImageEnhanceRequest( model=model, prompt=prompt, subject_detection=subject_detection, @@ -180,19 +198,18 @@ class TopazImageEnhance(IO.ComfyNode): await poll_op( cls, poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=topaz_api.ImageStatusResponse, + response_model=ImageStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: 
x.credits * 0.08, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=60, ) results = await sync_op( cls, ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=topaz_api.ImageDownloadResponse, + response_model=ImageDownloadResponse, monitor_progress=False, ) return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) @@ -330,7 +347,7 @@ class TopazVideoEnhance(IO.ComfyNode): if target_height % 2 != 0: target_height += 1 filters.append( - topaz_api.VideoEnhancementFilter( + VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), @@ -339,7 +356,7 @@ class TopazVideoEnhance(IO.ComfyNode): if interpolation_enabled: target_frame_rate = interpolation_frame_rate filters.append( - topaz_api.VideoFrameInterpolationFilter( + VideoFrameInterpolationFilter( model=interpolation_model, slowmo=interpolation_slowmo, fps=interpolation_frame_rate, @@ -350,19 +367,19 @@ class TopazVideoEnhance(IO.ComfyNode): initial_res = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=topaz_api.CreateVideoResponse, - data=topaz_api.CreateVideoRequest( - source=topaz_api.CreateCreateVideoRequestSource( + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), frameCount=video.get_frame_count(), frameRate=src_frame_rate, - resolution=topaz_api.Resolution(width=src_width, height=src_height), + resolution=Resolution(width=src_width, height=src_height), ), filters=filters, - output=topaz_api.OutputInformationVideo( - resolution=topaz_api.Resolution(width=target_width, height=target_height), + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), frameRate=target_frame_rate, audioCodec="AAC", audioTransfer="Copy", @@ -378,7 +395,7 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/accept", method="PATCH", ), - response_model=topaz_api.VideoAcceptResponse, + response_model=VideoAcceptResponse, wait_label="Preparing upload", final_label_on_success="Upload started", ) @@ -401,10 +418,10 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", method="PATCH", ), - response_model=topaz_api.VideoCompleteUploadResponse, - data=topaz_api.VideoCompleteUploadRequest( + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( uploadResults=[ - topaz_api.VideoCompleteUploadRequestPart( + VideoCompleteUploadRequestPart( partNum=1, eTag=upload_etag, ), @@ -416,7 +433,7 @@ class TopazVideoEnhance(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=topaz_api.VideoStatusResponse, + response_model=VideoStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index e72f8e96a..5abf27b4d 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -5,7 +5,7 @@ import torch from typing_extensions import override from 
comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.tripo_api import ( +from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, TripoConvertModelRequest, @@ -117,6 +117,38 @@ class TripoTextToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 20 : ($withTexture ? 20 : 10); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -210,6 +242,38 @@ class TripoImageToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -314,6 +378,34 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod @@ -405,6 +497,15 @@ class TripoTextureNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]), + expr=""" + ( + $tq := widgets.texture_quality; + {"type":"usd","usd": ($contains($tq,"detailed") ? 
0.2 : 0.1)} + ) + """, + ), ) @classmethod @@ -456,6 +557,9 @@ class TripoRefineNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.3}""", + ), ) @classmethod @@ -489,6 +593,9 @@ class TripoRigNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -545,6 +652,9 @@ class TripoRetargetNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1}""", + ), ) @classmethod @@ -638,6 +748,60 @@ class TripoConversionNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "quad", + "face_limit", + "texture_size", + "texture_format", + "force_symmetry", + "flatten_bottom", + "flatten_bottom_threshold", + "pivot_to_center_bottom", + "scale_factor", + "with_animation", + "pack_uv", + "bake", + "part_names", + "fbx_preset", + "export_vertex_colors", + "export_orientation", + "animate_in_place", + ], + ), + expr=""" + ( + $face := (widgets.face_limit != null) ? widgets.face_limit : -1; + $texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096; + $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; + $scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1; + $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); + $part := widgets.part_names; + $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); + $orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default"); + $advanced := + widgets.quad or + widgets.force_symmetry or + widgets.flatten_bottom or + widgets.pivot_to_center_bottom or + widgets.with_animation or + widgets.pack_uv or + widgets.bake or + widgets.export_vertex_colors or + widgets.animate_in_place or + ($face != -1) or + ($texSize != 4096) or + ($flatThresh != 0) or + ($scale != 1) or + ($texFmt != "jpeg") or + ($part != "") or + ($fbx != "blender") or + ($orient != "default"); + {"type":"usd","usd": ($advanced ? 0.1 : 0.05)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13a6bfd91..2a202fc3b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -4,7 +4,7 @@ from io import BytesIO from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.veo_api import ( +from comfy_api_nodes.apis.veo import ( VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidRequest, @@ -122,6 +122,10 @@ class VeoVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]), + expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""", + ), ) @classmethod @@ -347,6 +351,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + expr=""" + ( + $m := widgets.model; + $a := widgets.generate_audio; + ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) + ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} + : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) + ? {"type":"usd","usd": ($a ? 
3.2 : 1.6)} + : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + ) + """, + ), ) @@ -420,6 +438,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + expr=""" + ( + $prices := { + "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, + "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } + }; + $m := widgets.model; + $ga := (widgets.generate_audio = "true"); + $seconds := widgets.duration; + $modelKey := + $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : + $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : + ""; + $audioKey := $ga ? "audio" : "no_audio"; + $modelPrices := $lookup($prices, $modelKey); + $pps := $lookup($modelPrices, $audioKey); + ($pps != null) + ? {"type":"usd","usd": $pps * $seconds} + : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 7a679f0d9..80de14dfe 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,22 +1,30 @@ -import logging -from enum import Enum -from typing import Literal, Optional, TypeVar - -import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.vidu import ( + FrameSetting, + SubjectReference, + TaskCreationRequest, + TaskCreationResponse, + TaskExtendCreationRequest, + TaskMultiFrameCreationRequest, + TaskResult, + TaskStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_video_output, get_number_of_images, poll_op, sync_op, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, + validate_string, + validate_video_duration, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -25,98 +33,34 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" -R = TypeVar("R") - - -class VideoModelName(str, Enum): - vidu_q1 = "viduq1" - - -class AspectRatio(str, Enum): - r_16_9 = "16:9" - r_9_16 = "9:16" - r_1_1 = "1:1" - - -class Resolution(str, Enum): - r_1080p = "1080p" - - -class MovementAmplitude(str, Enum): - auto = "auto" - small = "small" - medium = "medium" - large = "large" - - -class TaskCreationRequest(BaseModel): - model: VideoModelName = VideoModelName.vidu_q1 - prompt: Optional[str] = Field(None, max_length=1500) - duration: Optional[Literal[5]] = 5 - seed: Optional[int] = Field(0, ge=0, le=2147483647) - aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 - resolution: Optional[Resolution] = Resolution.r_1080p - movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto - images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") - - -class TaskCreationResponse(BaseModel): - task_id: str = Field(...) - state: str = Field(...) - created_at: str = Field(...) 
- code: Optional[int] = Field(None, description="Error code") - - -class TaskResult(BaseModel): - id: str = Field(..., description="Creation id") - url: str = Field(..., description="The URL of the generated results, valid for one hour") - cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") - - -class TaskStatusResponse(BaseModel): - state: str = Field(...) - err_code: Optional[str] = Field(None) - creations: list[TaskResult] = Field(..., description="Generated results") - - -def get_video_url_from_response(response) -> Optional[str]: - if response.creations: - return response.creations[0].url - return None - - -def get_video_from_response(response) -> TaskResult: - if not response.creations: - error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" - logging.info(error_msg) - raise RuntimeError(error_msg) - logging.info("Vidu task %s succeeded. Video URL: %s", response.creations[0].id, response.creations[0].url) - return response.creations[0] - async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, - payload: TaskCreationRequest, - estimated_duration: int, -) -> R: - response = await sync_op( + payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, + max_poll_attempts: int = 320, +) -> list[TaskResult]: + task_creation_response = await sync_op( cls, endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), response_model=TaskCreationResponse, data=payload, ) - if response.state == "failed": - error_msg = f"Vidu request failed. Code: {response.code}" - logging.error(error_msg) - raise RuntimeError(error_msg) - return await poll_op( + if task_creation_response.state == "failed": + raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}") + response = await poll_op( cls, - ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id), response_model=TaskStatusResponse, status_extractor=lambda r: r.state, - estimated_duration=estimated_duration, + progress_extractor=lambda r: r.progress, + max_poll_attempts=max_poll_attempts, ) + if not response.creations: + raise RuntimeError( + f"Vidu request does not contain results. 
State: {response.state}, Error Code: {response.err_code}" + ) + return response.creations class ViduTextToVideoNode(IO.ComfyNode): @@ -127,14 +71,9 @@ class ViduTextToVideoNode(IO.ComfyNode): node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", category="api node/video/Vidu", - description="Generate video from text prompt", + description="Generate video from a text prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.String.Input( "prompt", multiline=True, @@ -163,22 +102,19 @@ class ViduTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -192,6 +128,9 @@ class ViduTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -208,7 +147,7 @@ class ViduTextToVideoNode(IO.ComfyNode): if not prompt: raise ValueError("The prompt field is required and cannot be empty.") payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -216,8 +155,8 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduImageToVideoNode(IO.ComfyNode): @@ -230,12 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate video from image and optional prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "image", tooltip="An image to be used as the start frame of the generated video", @@ -270,15 +204,13 @@ class ViduImageToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -292,13 +224,16 @@ class ViduImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, duration: int, seed: int, @@ -309,7 +244,7 @@ class ViduImageToVideoNode(IO.ComfyNode): raise ValueError("Only one input 
image is allowed.") validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -322,8 +257,8 @@ class ViduImageToVideoNode(IO.ComfyNode): max_images=1, mime_type="image/png", ) - results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduReferenceVideoNode(IO.ComfyNode): @@ -334,14 +269,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", category="api node/video/Vidu", - description="Generate video from multiple images and prompt", + description="Generate video from multiple images and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "images", tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", @@ -374,22 +304,19 @@ class ViduReferenceVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -403,13 +330,16 @@ class ViduReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - images: torch.Tensor, + images: Input.Image, prompt: str, duration: int, seed: int, @@ -426,7 +356,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -440,8 +370,8 @@ class ViduReferenceVideoNode(IO.ComfyNode): max_images=7, mime_type="image/png", ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduStartEndToVideoNode(IO.ComfyNode): @@ -454,12 +384,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "first_frame", tooltip="Start frame", @@ -497,15 +422,13 @@ class 
ViduStartEndToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, ), @@ -519,14 +442,17 @@ class ViduStartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - first_frame: torch.Tensor, - end_frame: torch.Tensor, + first_frame: Input.Image, + end_frame: Input.Image, prompt: str, duration: int, seed: int, @@ -535,7 +461,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -546,8 +472,1013 @@ class ViduStartEndToVideoNode(IO.ComfyNode): (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2TextToVideoNode", + display_name="Vidu2 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Boolean.Input( + "background_music", + default=False, + tooltip="Whether to add background music to the generated video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.1 : 0.075; + $perSec := $is1080 ? 
0.05 : 0.025; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + background_music: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + bgm=background_music, + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ImageToVideoNode", + display_name="Vidu2 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + ), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 1 ? {"type":"usd","usd": 0.04} + : $d <= 2 ? 
{"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) > 1: + raise ValueError("Only one input image is allowed.") + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_string(prompt, max_length=2000) + results = await execute_task( + cls, + VIDU_IMAGE_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type="image/png", + ), + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ReferenceVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ReferenceVideoNode", + display_name="Vidu2 Reference-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from multiple reference images and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.Autogrow.Input( + "subjects", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_images"), + names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"], + min=1, + ), + tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). " + "Reference them in prompts via @subject{subject_id}.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="When enabled, the video will include generated speech and background music " + "based on the prompt.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled video will contain generated speech and background music based on the prompt.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.375 : 0.125; + $perSec := $is1080 ? 0.05 : 0.025; + $audioCost := widgets.audio = true ? 
0.075 : 0; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + subjects: IO.Autogrow.Type, + prompt: str, + audio: bool, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + total_images = 0 + for i in subjects: + if get_number_of_images(subjects[i]) > 3: + raise ValueError("Maximum number of images per subject is 3.") + for im in subjects[i]: + total_images += 1 + validate_image_aspect_ratio(im, (1, 4), (4, 1)) + validate_image_dimensions(im, min_width=128, min_height=128) + if total_images > 7: + raise ValueError("Too many reference images; the maximum allowed is 7.") + subjects_param: list[SubjectReference] = [] + for i in subjects: + subjects_param.append( + SubjectReference( + id=i, + images=await upload_images_to_comfyapi( + cls, + subjects[i], + max_images=3, + mime_type="image/png", + wait_label=f"Uploading reference images for {i}", + ), + ), + ) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + audio=audio, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + subjects=subjects_param, + ) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2StartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2StartEndToVideoNode", + display_name="Vidu2 Start/End Frame-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from a start frame, an end frame, and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt description (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=2, + max=8, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 2 ? 
{"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + if get_number_of_images(first_frame) > 1: + raise ValueError("Only one input image is allowed for `first_frame`.") + if get_number_of_images(end_frame) > 1: + raise ValueError("Only one input image is allowed for `end_frame`.") + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class ViduExtendVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduExtendVideoNode", + display_name="Vidu Video Extension", + category="api node/video/Vidu", + description="Extend an existing video by generating additional frames.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq2-pro", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq2-turbo", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + ], + tooltip="Model to use for video extension.", + ), + IO.Video.Input( + "video", + tooltip="The source video to extend.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for the extended video (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Image.Input("end_frame", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $m := widgets.model; + $d := $lookup(widgets, "model.duration"); + $res := $lookup(widgets, "model.resolution"); + $contains($m, "pro") + ? 
( + $base := $lookup({"720p": 0.15, "1080p": 0.3}, $res); + $perSec := $lookup({"720p": 0.05, "1080p": 0.075}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : ( + $base := $lookup({"720p": 0.075, "1080p": 0.2}, $res); + $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + video: Input.Video, + prompt: str, + seed: int, + end_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + validate_video_duration(video, min_duration=4, max_duration=55) + image_url = None + if end_frame is not None: + validate_image_aspect_ratio(end_frame, (1, 4), (4, 1)) + validate_image_dimensions(end_frame, min_width=128, min_height=128) + image_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + results = await execute_task( + cls, + "/proxy/vidu/extend", + TaskExtendCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), + images=[image_url] if image_url else None, + ), + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +def _generate_frame_inputs(count: int) -> list: + """Generate input widgets for a given number of frames.""" + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"prompt{i}", + multiline=True, + default="", + tooltip=f"Text prompt for frame {i} transition.", + ), + IO.Image.Input( + f"end_image{i}", + tooltip=f"End frame image for segment {i}. Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + f"duration{i}", + default=4, + min=2, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip=f"Duration for segment {i} in seconds.", + ), + ] + ) + return inputs + + +class ViduMultiFrameVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduMultiFrameVideoNode", + display_name="Vidu Multi-Frame Video Generation", + category="api node/video/Vidu", + description="Generate a video with multiple keyframe transitions.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "start_image", + tooltip="The starting frame image. 
Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.DynamicCombo.Input( + "frames", + options=[ + IO.DynamicCombo.Option("2", _generate_frame_inputs(2)), + IO.DynamicCombo.Option("3", _generate_frame_inputs(3)), + IO.DynamicCombo.Option("4", _generate_frame_inputs(4)), + IO.DynamicCombo.Option("5", _generate_frame_inputs(5)), + IO.DynamicCombo.Option("6", _generate_frame_inputs(6)), + IO.DynamicCombo.Option("7", _generate_frame_inputs(7)), + IO.DynamicCombo.Option("8", _generate_frame_inputs(8)), + IO.DynamicCombo.Option("9", _generate_frame_inputs(9)), + ], + tooltip="Number of keyframe transitions (2-9).", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model", + "resolution", + "frames", + "frames.duration1", + "frames.duration2", + "frames.duration3", + "frames.duration4", + "frames.duration5", + "frames.duration6", + "frames.duration7", + "frames.duration8", + "frames.duration9", + ] + ), + expr=""" + ( + $m := widgets.model; + $n := $number(widgets.frames); + $is1080 := widgets.resolution = "1080p"; + $d1 := $lookup(widgets, "frames.duration1"); + $d2 := $lookup(widgets, "frames.duration2"); + $d3 := $n >= 3 ? $lookup(widgets, "frames.duration3") : 0; + $d4 := $n >= 4 ? $lookup(widgets, "frames.duration4") : 0; + $d5 := $n >= 5 ? $lookup(widgets, "frames.duration5") : 0; + $d6 := $n >= 6 ? $lookup(widgets, "frames.duration6") : 0; + $d7 := $n >= 7 ? $lookup(widgets, "frames.duration7") : 0; + $d8 := $n >= 8 ? $lookup(widgets, "frames.duration8") : 0; + $d9 := $n >= 9 ? $lookup(widgets, "frames.duration9") : 0; + $totalDuration := $d1 + $d2 + $d3 + $d4 + $d5 + $d6 + $d7 + $d8 + $d9; + $contains($m, "pro") + ? ( + $base := $is1080 ? 0.3 : 0.15; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + : ( + $base := $is1080 ? 0.2 : 0.075; + $perSec := $is1080 ? 
0.05 : 0.025; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + start_image: Input.Image, + seed: int, + resolution: str, + frames: dict, + ) -> IO.NodeOutput: + validate_image_aspect_ratio(start_image, (1, 4), (4, 1)) + frame_count = int(frames["frames"]) + image_settings: list[FrameSetting] = [] + for i in range(1, frame_count + 1): + validate_image_aspect_ratio(frames[f"end_image{i}"], (1, 4), (4, 1)) + validate_string(frames[f"prompt{i}"], max_length=2000) + start_image_url = await upload_image_to_comfyapi( + cls, + start_image, + mime_type="image/png", + wait_label="Uploading start image", + ) + for i in range(1, frame_count + 1): + image_settings.append( + FrameSetting( + prompt=frames[f"prompt{i}"], + key_image=await upload_image_to_comfyapi( + cls, + frames[f"end_image{i}"], + mime_type="image/png", + wait_label=f"Uploading end image({i})", + ), + duration=frames[f"duration{i}"], + ) + ) + results = await execute_task( + cls, + "/proxy/vidu/multiframe", + TaskMultiFrameCreationRequest( + model=model, + seed=seed, + resolution=resolution, + start_image=start_image_url, + image_settings=image_settings, + ), + max_poll_attempts=480 * frame_count, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3TextToVideoNode", + display_name="Vidu Q3 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "3:4", "4:3", "1:1"], + tooltip="The aspect ratio of the output video.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res); + $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); + {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model["model"], + prompt=prompt, + 
duration=model["duration"], + seed=seed, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + audio=model["audio"], + ), + max_poll_attempts=640, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3ImageToVideoNode", + display_name="Vidu Q3 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "2K"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res); + $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res); + {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + image: Input.Image, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_image_aspect_ratio(image, (1, 4), (4, 1)) + validate_string(prompt, max_length=2000) + results = await execute_task( + cls, + VIDU_IMAGE_TO_VIDEO, + TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + audio=model["audio"], + images=[await upload_image_to_comfyapi(cls, image)], + ), + max_poll_attempts=720, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduExtension(ComfyExtension): @@ -558,6 +1489,14 @@ class ViduExtension(ComfyExtension): ViduImageToVideoNode, ViduReferenceVideoNode, ViduStartEndToVideoNode, + Vidu2TextToVideoNode, + Vidu2ImageToVideoNode, + Vidu2ReferenceVideoNode, + Vidu2StartEndToVideoNode, + ViduExtendVideoNode, + ViduMultiFrameVideoNode, + Vidu3TextToVideoNode, + Vidu3ImageToVideoNode, ] diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 1675fd863..a1355d4f1 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,7 +13,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_video_to_comfyapi, validate_audio_duration, + validate_video_duration, ) @@ -41,6 +43,12 @@ class 
Image2VideoInputField(BaseModel): audio_url: str | None = Field(None) +class Reference2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + reference_video_urls: list[str] = Field(...) + + class Txt2ImageParametersField(BaseModel): size: str = Field(...) n: int = Field(1, description="Number of images to generate.") # we support only value=1 @@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel): shot_type: str = Field("single") +class Reference2VideoParametersField(BaseModel): + size: str = Field(...) + duration: int = Field(5, ge=5, le=15) + shot_type: str = Field("single") + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(False) + + class Text2ImageTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2ImageInputField = Field(...) @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel): parameters: Image2VideoParametersField = Field(...) +class Reference2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Reference2VideoInputField = Field(...) + parameters: Reference2VideoParametersField = Field(...) + + class TaskCreationOutputField(BaseModel): task_id: str = Field(...) task_status: str = Field(...) @@ -222,6 +244,9 @@ class WanTextToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -341,6 +366,9 @@ class WanImageToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -498,6 +526,17 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $resKey := $substringBefore(widgets.size, ":"); + $pps := $lookup($ppsTable, $resKey); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -659,6 +698,16 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $pps := $lookup($ppsTable, widgets.resolution); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -721,6 +770,159 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class WanReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanReferenceVideoApi", + display_name="Wan Reference to Video", + category="api node/video/Wan", + description="Use the character and voice from input videos, combined with a prompt, " + "to generate a new video that maintains character consistency.", + inputs=[ + IO.Combo.Input("model", options=["wan2.6-r2v"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese. 
" + "Use identifiers such as `character1` and `character2` to refer to the reference characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["character1", "character2", "character3"], + min=1, + ), + ), + IO.Combo.Input( + "size", + options=[ + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts.", + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]), + expr=""" + ( + $rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10; + $inputMin := 2 * $rate; + $inputMax := 5 * $rate; + $outputPrice := widgets.duration * $rate; + { + "type": "range_usd", + "min_usd": $inputMin + $outputPrice, + "max_usd": $inputMax + $outputPrice + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str, + reference_videos: IO.Autogrow.Type, + size: str, + duration: int, + seed: int, + shot_type: str, + watermark: bool, + ): + reference_video_urls = [] + for i in reference_videos: + validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) + for i in reference_videos: + reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + width, height = RES_IN_PARENS.search(size).groups() + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Reference2VideoTaskCreationRequest( + model=model, + input=Reference2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls + ), + parameters=Reference2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + shot_type=shot_type, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=6, + max_poll_attempts=280, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): @override async def 
get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -729,6 +931,7 @@ class WanApiExtension(ComfyExtension): WanImageToImageApi, WanTextToVideoApi, WanImageToVideoApi, + WanReferenceVideoApi, ] diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py new file mode 100644 index 000000000..c59fafd3b --- /dev/null +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -0,0 +1,178 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.wavespeed import ( + FlashVSRRequest, + TaskCreatedResponse, + TaskResultResponse, + SeedVR2ImageRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_video_output, + poll_op, + sync_op, + upload_video_to_comfyapi, + validate_container_format_is_mp4, + validate_video_duration, + upload_images_to_comfyapi, + get_number_of_images, + download_url_to_image_tensor, +) + + +class WavespeedFlashVSRNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedFlashVSRNode", + display_name="FlashVSR Video Upscale", + category="api node/video/WaveSpeed", + description="Fast, high-quality video upscaler that " + "boosts resolution and restores clarity for low-resolution or blurry footage.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]), + expr=""" + ( + $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2k": 0.024, "4k": 0.032}; + { + "type":"usd", + "usd": $lookup($price_for_1sec, widgets.target_resolution), + "format":{"suffix": "/second", "approximate": true} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + target_resolution: str, + ) -> IO.NodeOutput: + validate_container_format_is_mp4(video) + validate_video_duration(video, min_duration=5, max_duration=60 * 10) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"), + response_model=TaskCreatedResponse, + data=FlashVSRRequest( + target_resolution=target_resolution.lower(), + video=await upload_video_to_comfyapi(cls, video), + duration=video.get_duration(), + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0])) + + +class WavespeedImageUpscaleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WavespeedImageUpscaleNode", + display_name="WaveSpeed Image Upscale", + category="api node/image/WaveSpeed", + description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", + inputs=[ + IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), 
+ IO.Image.Input("image"), + IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := {"seedvr2": 0.01, "ultimate": 0.06}; + {"type":"usd", "usd": $lookup($prices, widgets.model)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + target_resolution: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if model == "SeedVR2": + model_path = "seedvr2/image" + else: + model_path = "ultimate-image-upscaler" + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), + response_model=TaskCreatedResponse, + data=SeedVR2ImageRequest( + target_resolution=target_resolution.lower(), + image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) + + +class WavespeedExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + WavespeedFlashVSRNode, + WavespeedImageUpscaleNode, + ] + + +async def comfy_entrypoint() -> WavespeedExtension: + return WavespeedExtension() diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml deleted file mode 100644 index d9e3cab70..000000000 --- a/comfy_api_nodes/redocly-dev.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. -# This is used for development purposes to generate stubs for unreleased API endpoints. -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes'] - matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml deleted file mode 100644 index d102345b1..000000000 --- a/comfy_api_nodes/redocly.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. 
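Note: both new WaveSpeed nodes above use the same create-then-poll flow: sync_op submits the task, the response code is checked, poll_op waits on /predictions/{id}/result, and the first output URL is downloaded. The following is a minimal sketch of that shared pattern, assuming a hypothetical _await_wavespeed_result helper (not part of this PR) and the sync_op/poll_op signatures used above.

from comfy_api_nodes.apis.wavespeed import TaskCreatedResponse, TaskResultResponse
from comfy_api_nodes.util import ApiEndpoint, poll_op, sync_op

async def _await_wavespeed_result(cls, create_path: str, request_data) -> str:
    # Hypothetical helper illustrating the flow shared by both nodes; illustrative only.
    created = await sync_op(
        cls,
        ApiEndpoint(path=create_path, method="POST"),
        response_model=TaskCreatedResponse,
        data=request_data,
    )
    if created.code != 200:
        raise ValueError(f"Task creation failed with code={created.code} and message={created.message}")
    result = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{created.data.id}/result"),
        response_model=TaskResultResponse,
        status_extractor=lambda x: "failed" if x.data is None else x.data.status,
        poll_interval=10.0,
        max_poll_attempts=480,
    )
    if result.code != 200:
        raise ValueError(f"Task processing failed with code={result.code} and message={result.message}")
    # URL of the first generated output, ready for download_url_to_video_output / download_url_to_image_tensor
    return result.data.outputs[0]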
- -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes', 'Released'] - matchStrategy: all diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 4cc22abfb..c3c9ff4bf 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -11,7 +11,9 @@ from .conversions import ( audio_input_to_mp3, audio_to_base64_string, bytesio_to_image_tensor, + convert_mask_to_image, downscale_image_tensor, + downscale_image_tensor_by_max_side, image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, @@ -32,6 +34,7 @@ from .download_helpers import ( from .upload_helpers import ( upload_audio_to_comfyapi, upload_file_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, ) @@ -60,6 +63,7 @@ __all__ = [ # Upload helpers "upload_audio_to_comfyapi", "upload_file_to_comfyapi", + "upload_image_to_comfyapi", "upload_images_to_comfyapi", "upload_video_to_comfyapi", # Download helpers @@ -72,7 +76,9 @@ __all__ = [ "audio_input_to_mp3", "audio_to_base64_string", "bytesio_to_image_tensor", + "convert_mask_to_image", "downscale_image_tensor", + "downscale_image_tensor_by_max_side", "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index f372ec7b5..8a1259506 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -141,7 +141,7 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, @@ -238,7 +238,7 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, max_retries_per_poll: int = 3, retry_delay_per_poll: float = 1.0, diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index d64239c86..3e37e8a8c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -55,16 +55,15 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: str | None = None, - total_pixels: int = 2048 * 2048, + *, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> BytesIO: """Converts a torch.Tensor image to a named BytesIO object. Args: image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). 
Returns: @@ -75,17 +74,18 @@ def tensor_to_bytesio( pil_image = tensor_to_pil(image, total_pixels=total_pixels) img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}" return img_binary -def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: +def tensor_to_pil(image: torch.Tensor, total_pixels: int | None = 2048 * 2048) -> Image.Image: """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" if len(image.shape) > 3: image = image[0] # TODO: remove alpha if not allowed and present input_tensor = image.cpu() - input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + if total_pixels is not None: + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() image_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) return img @@ -93,14 +93,14 @@ def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image def tensor_to_base64_string( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -144,16 +144,31 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) return s +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: + """Downscale input image tensor so the largest dimension is at most max_side pixels.""" + samples = image.movedim(-1, 1) + height, width = samples.shape[2], samples.shape[3] + max_dim = max(width, height) + if max_dim <= max_side: + return image + scale_by = max_side / max_dim + new_width = round(width * scale_by) + new_height = round(height * scale_by) + s = common_upscale(samples, new_width, new_height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + def tensor_to_data_uri( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Converts a tensor image to a Data URI string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). 
Returns: @@ -451,6 +466,12 @@ def resize_mask_to_image( return mask +def convert_mask_to_image(mask: Input.Image) -> torch.Tensor: + """Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.""" + mask = mask.unsqueeze(-1) + return torch.cat([mask] * 3, dim=-1) + + def text_filepath_to_base64_string(filepath: str) -> str: """Converts a text file to a base64 string.""" with open(filepath, "rb") as f: diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py index b8d33f4d1..3153f2b98 100644 --- a/comfy_api_nodes/util/upload_helpers.py +++ b/comfy_api_nodes/util/upload_helpers.py @@ -43,27 +43,41 @@ class UploadResponse(BaseModel): async def upload_images_to_comfyapi( cls: type[IO.ComfyNode], - image: torch.Tensor, + image: torch.Tensor | list[torch.Tensor], *, max_images: int = 8, mime_type: str | None = None, wait_label: str | None = "Uploading", show_batch_index: bool = True, + total_pixels: int | None = 2048 * 2048, ) -> list[str]: """ Uploads images to ComfyUI API and returns download URLs. To upload multiple images, stack them in the batch dimension first. """ + tensors: list[torch.Tensor] = [] + if isinstance(image, list): + for img in image: + is_batch = len(img.shape) > 3 + if is_batch: + tensors.extend(img[i] for i in range(img.shape[0])) + else: + tensors.append(img) + else: + is_batch = len(image.shape) > 3 + if is_batch: + tensors.extend(image[i] for i in range(image.shape[0])) + else: + tensors.append(image) + # if batched, try to upload each file if max_images is greater than 0 download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - num_to_upload = min(batch_len, max_images) + num_to_upload = min(len(tensors), max_images) batch_start_ts = time.monotonic() for idx in range(num_to_upload): - tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + tensor = tensors[idx] + img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type) effective_label = wait_label if wait_label and show_batch_index and num_to_upload > 1: @@ -74,6 +88,28 @@ async def upload_images_to_comfyapi( return download_urls +async def upload_image_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + mime_type: str | None = None, + wait_label: str | None = "Uploading", + total_pixels: int = 2048 * 2048, +) -> str: + """Uploads a single image to ComfyUI API and returns its download URL.""" + return ( + await upload_images_to_comfyapi( + cls, + image, + max_images=1, + mime_type=mime_type, + wait_label=wait_label, + show_batch_index=False, + total_pixels=total_pixels, + ) + )[0] + + async def upload_audio_to_comfyapi( cls: type[IO.ComfyNode], audio: Input.Audio, @@ -81,7 +117,6 @@ async def upload_audio_to_comfyapi( container_format: str = "mp4", codec_name: str = "aac", mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", ) -> str: """ Uploads a single audio input to ComfyUI API and returns its download URL. 
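The new upload_image_to_comfyapi wrapper and the total_pixels pass-through make the common single-image case a one-liner and let callers opt out of downscaling. A short usage sketch follows; the surrounding function and tensor names are assumed for illustration, the helper signatures are the ones added in this diff.

from comfy_api_nodes.util import upload_image_to_comfyapi, upload_images_to_comfyapi

async def _upload_inputs(cls, image, image_a, image_b):
    # Illustrative only: how an API node's async execute() might call the new helpers.
    url = await upload_image_to_comfyapi(cls, image)           # single tensor, capped at 2048*2048 pixels
    urls = await upload_images_to_comfyapi(
        cls,
        [image_a, image_b],    # lists of tensors (or batched tensors) are now accepted
        max_images=2,
        total_pixels=None,     # None disables downscaling entirely
    )
    return url, urls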
@@ -91,7 +126,7 @@ async def upload_audio_to_comfyapi( waveform: torch.Tensor = audio["waveform"] audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) async def upload_video_to_comfyapi( @@ -119,7 +154,7 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" + filename = f"{uuid.uuid4()}.{container.value.lower()}" # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 59fb49357..bf091a448 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -14,8 +14,9 @@ class JobStatus: IN_PROGRESS = 'in_progress' COMPLETED = 'completed' FAILED = 'failed' + CANCELLED = 'cancelled' - ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] # Media types that can be previewed in the frontend @@ -94,12 +95,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: status_info = history_item.get('status', {}) status_str = status_info.get('status_str') if status_info else None - if status_str == 'success': - status = JobStatus.COMPLETED - elif status_str == 'error': - status = JobStatus.FAILED - else: - status = JobStatus.COMPLETED outputs = history_item.get('outputs', {}) outputs_count, preview_output = get_outputs_summary(outputs) @@ -107,6 +102,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_error = None execution_start_time = None execution_end_time = None + was_interrupted = False if status_info: messages = status_info.get('messages', []) for entry in messages: @@ -119,6 +115,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_end_time = event_data.get('timestamp') if event_name == 'execution_error': execution_error = event_data + elif event_name == 'execution_interrupted': + was_interrupted = True + + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED + else: + status = JobStatus.COMPLETED job = prune_dict({ 'id': prompt_id, @@ -166,9 +171,10 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: continue for item in items: + count += 1 + if not isinstance(item, dict): continue - count += 1 if preview_output is None and is_previewable(media_type, item): enriched = { @@ -268,13 +274,13 @@ def get_all_jobs( for item in queued: jobs.append(normalize_queue_item(item, JobStatus.PENDING)) - include_completed = JobStatus.COMPLETED in status_filter - include_failed = JobStatus.FAILED in status_filter - if include_completed or include_failed: + history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} + requested_history_statuses = history_statuses & set(status_filter) + if requested_history_statuses: for prompt_id, history_item in history.items(): - is_failed = history_item.get('status', {}).get('status_str') == 'error' - if (is_failed and include_failed) or (not is_failed and include_completed): - 
jobs.append(normalize_history_item(prompt_id, history_item)) + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index edd5dadd4..4fc511d2c 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="AlignYourStepsScheduler", + search_aliases=["AYS scheduler"], category="sampling/custom_sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index c0e494c2a..67c4e2ed0 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -71,6 +71,7 @@ class CLIPAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CLIPAttentionMultiply", + search_aliases=["clip attention scale", "text encoder attention"], category="_for_testing/attention_experiments", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443c..271b75fbd 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -69,6 +69,7 @@ class VAEEncodeAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEEncodeAudio", + search_aliases=["audio to latent"], display_name="VAE Encode Audio", category="latent/audio", inputs=[ @@ -97,6 +98,7 @@ class VAEDecodeAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEDecodeAudio", + search_aliases=["latent to audio"], display_name="VAE Decode Audio", category="latent/audio", inputs=[ @@ -112,7 +114,7 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove @@ -122,6 +124,7 @@ class SaveAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudio", + search_aliases=["export flac"], display_name="Save Audio (FLAC)", category="audio", inputs=[ @@ -146,6 +149,7 @@ class SaveAudioMP3(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioMP3", + search_aliases=["export mp3"], display_name="Save Audio (MP3)", category="audio", inputs=[ @@ -173,6 +177,7 @@ class SaveAudioOpus(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioOpus", + search_aliases=["export opus"], display_name="Save Audio (Opus)", category="audio", inputs=[ @@ -200,6 +205,7 @@ class PreviewAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="PreviewAudio", + search_aliases=["play audio"], display_name="Preview Audio", category="audio", inputs=[ @@ -259,6 +265,7 @@ class LoadAudio(IO.ComfyNode): files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) return IO.Schema( node_id="LoadAudio", + search_aliases=["import audio", "open audio", "audio file"], display_name="Load Audio", category="audio", inputs=[ @@ -296,6 +303,7 @@ class RecordAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( 
node_id="RecordAudio", + search_aliases=["microphone input", "audio capture", "voice input"], display_name="Record Audio", category="audio", inputs=[ @@ -320,6 +328,7 @@ class TrimAudioDuration(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TrimAudioDuration", + search_aliases=["cut audio", "audio clip", "shorten audio"], display_name="Trim Audio Duration", description="Trim audio tensor into chosen time range.", category="audio", @@ -372,6 +381,7 @@ class SplitAudioChannels(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SplitAudioChannels", + search_aliases=["stereo to mono"], display_name="Split Audio Channels", description="Separates the audio into left and right channels.", category="audio", @@ -399,6 +409,58 @@ class SplitAudioChannels(IO.ComfyNode): separate = execute # TODO: remove +class JoinAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="JoinAudioChannels", + display_name="Join Audio Channels", + description="Joins left and right mono audio channels into a stereo audio.", + category="audio", + inputs=[ + IO.Audio.Input("audio_left"), + IO.Audio.Input("audio_right"), + ], + outputs=[ + IO.Audio.Output(display_name="audio"), + ], + ) + + @classmethod + def execute(cls, audio_left, audio_right) -> IO.NodeOutput: + waveform_left = audio_left["waveform"] + sample_rate_left = audio_left["sample_rate"] + waveform_right = audio_right["waveform"] + sample_rate_right = audio_right["sample_rate"] + + if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1: + raise ValueError("AudioJoin: Both input audios must be mono.") + + # Handle different sample rates by resampling to the higher rate + waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates( + waveform_left, sample_rate_left, waveform_right, sample_rate_right + ) + + # Handle different lengths by trimming to the shorter length + length_left = waveform_left.shape[-1] + length_right = waveform_right.shape[-1] + + if length_left != length_right: + min_length = min(length_left, length_right) + if length_left > min_length: + logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.") + waveform_left = waveform_left[..., :min_length] + if length_right > min_length: + logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.") + waveform_right = waveform_right[..., :min_length] + + # Join the channels into stereo + left_channel = waveform_left[..., 0:1, :] + right_channel = waveform_right[..., 0:1, :] + stereo_waveform = torch.cat([left_channel, right_channel], dim=1) + + return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate}) + def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): if sample_rate_1 != sample_rate_2: @@ -420,6 +482,7 @@ class AudioConcat(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="AudioConcat", + search_aliases=["join audio", "combine audio", "append audio"], display_name="Audio Concat", description="Concatenates the audio1 to audio2 in the specified direction.", category="audio", @@ -467,6 +530,7 @@ class AudioMerge(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="AudioMerge", + search_aliases=["mix audio", "overlay audio", "layer audio"], display_name="Audio Merge", description="Combine two audio tracks by overlaying their waveforms.", category="audio", @@ -527,6 +591,7 @@ class AudioAdjustVolume(IO.ComfyNode): def 
define_schema(cls): return IO.Schema( node_id="AudioAdjustVolume", + search_aliases=["audio gain", "loudness", "audio level"], display_name="Audio Adjust Volume", category="audio", inputs=[ @@ -562,6 +627,7 @@ class EmptyAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="EmptyAudio", + search_aliases=["blank audio"], display_name="Empty Audio", category="audio", inputs=[ @@ -616,6 +682,7 @@ class AudioExtension(ComfyExtension): RecordAudio, TrimAudioDuration, SplitAudioChannels, + JoinAudioChannels, AudioConcat, AudioMerge, AudioAdjustVolume, diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 576f3640a..6e0fadca5 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -10,6 +10,7 @@ class Canny(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Canny", + search_aliases=["edge detection", "outline", "contour detection", "line art"], category="image/preprocessors", inputs=[ io.Image.Input("image"), diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py new file mode 100644 index 000000000..80ba121cd --- /dev/null +++ b/comfy_extras/nodes_color.py @@ -0,0 +1,42 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class ColorToRGBInt(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ColorToRGBInt", + display_name="Color to RGB Int", + category="utils", + description="Convert a color to a RGB integer value.", + inputs=[ + io.Color.Input("color"), + ], + outputs=[ + io.Int.Output(display_name="rgb_int"), + ], + ) + + @classmethod + def execute( + cls, + color: str, + ) -> io.NodeOutput: + # expect format #RRGGBB + if len(color) != 7 or color[0] != "#": + raise ValueError("Color must be in format #RRGGBB") + r = int(color[1:3], 16) + g = int(color[3:5], 16) + b = int(color[5:7], 16) + return io.NodeOutput(r * 256 * 256 + g * 256 + b) + + +class ColorExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ColorToRGBInt] + + +async def comfy_entrypoint() -> ColorExtension: + return ColorExtension() diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py index e4e4e1cbc..3bc9fccb3 100644 --- a/comfy_extras/nodes_compositing.py +++ b/comfy_extras/nodes_compositing.py @@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PorterDuffImageComposite", + search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"], display_name="Porter-Duff Image Composite", category="mask/compositing", inputs=[ @@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitImageWithAlpha", + search_aliases=["extract alpha", "separate transparency", "remove alpha"], display_name="Split Image with Alpha", category="mask/compositing", inputs=[ @@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="JoinImageWithAlpha", + search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"], display_name="Join Image with Alpha", category="mask/compositing", inputs=[ diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index e835feed7..0c1d7f0d4 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -38,6 +38,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): def define_schema(cls): 
return io.Schema( node_id="ControlNetInpaintingAliMamaApply", + search_aliases=["masked controlnet"], category="conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index f19adf4b9..8afd13acf 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -297,6 +297,7 @@ class ExtendIntermediateSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ExtendIntermediateSigmas", + search_aliases=["interpolate sigmas"], category="sampling/custom_sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), @@ -700,7 +701,14 @@ class Noise_EmptyNoise: def generate_noise(self, input_latent): latent_image = input_latent["samples"] - return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + if latent_image.is_nested: + tensors = latent_image.unbind() + zeros = [] + for t in tensors: + zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu")) + return comfy.nested_tensor.NestedTensor(zeros) + else: + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") class Noise_RandomNoise: @@ -740,7 +748,7 @@ class SamplerCustom(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image if not add_noise: @@ -759,6 +767,7 @@ class SamplerCustom(io.ComfyNode): samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) @@ -856,6 +865,7 @@ class DualCFGGuider(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DualCFGGuider", + search_aliases=["dual prompt guidance"], category="sampling/custom_sampling/guiders", inputs=[ io.Model.Input("model"), @@ -883,6 +893,7 @@ class DisableNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DisableNoise", + search_aliases=["zero noise"], category="sampling/custom_sampling/noise", inputs=[], outputs=[io.Noise.Output()] @@ -936,7 +947,7 @@ class SamplerCustomAdvanced(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image noise_mask = None @@ -951,6 +962,7 @@ class SamplerCustomAdvanced(io.ComfyNode): samples = samples.to(comfy.model_management.intermediate_device()) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) @@ -1019,6 +1031,7 @@ class ManualSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ManualSigmas", + search_aliases=["custom noise schedule", "define sigmas"], category="_for_testing/custom_sampling", is_experimental=True, inputs=[ diff --git 
a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 5ef851bd0..fb9409ac3 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -1223,11 +1223,11 @@ class ResolutionBucket(io.ComfyNode): class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" - @classmethod def define_schema(cls): return io.Schema( node_id="MakeTrainingDataset", + search_aliases=["encode dataset"], display_name="Make Training Dataset", category="dataset", is_experimental=True, @@ -1309,11 +1309,11 @@ class MakeTrainingDataset(io.ComfyNode): class SaveTrainingDataset(io.ComfyNode): """Save encoded training dataset (latents + conditioning) to disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="SaveTrainingDataset", + search_aliases=["export training data"], display_name="Save Training Dataset", category="dataset", is_experimental=True, @@ -1410,11 +1410,11 @@ class SaveTrainingDataset(io.ComfyNode): class LoadTrainingDataset(io.ComfyNode): """Load encoded training dataset from disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="LoadTrainingDataset", + search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", category="dataset", is_experimental=True, diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 6dfdf466c..34ffb9a89 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DifferentialDiffusion", + search_aliases=["inpaint gradient", "variable denoise strength"], display_name="Differential Diffusion", category="_for_testing", inputs=[ diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdb..90d730df6 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -29,8 +29,10 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. 
Present uuids: {uuids}") return easycache.apply_cache_diff(x, uuids) @@ -44,7 +46,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values @@ -240,6 +242,9 @@ class EasyCacheHolder: return to_return.clone() return to_return + def can_apply_cache_diff(self, uuids: list[UUID]) -> bool: + return all(uuid in self.uuid_cache_diffs for uuid in uuids) + def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]): if self.first_cond_uuid in uuids: self.total_steps_skipped += 1 diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index f308eb0c1..3d590af4b 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -58,6 +58,7 @@ class FreSca(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FreSca", + search_aliases=["frequency guidance"], display_name="FreSca", category="_for_testing", description="Applies frequency-dependent scaling to the guidance", diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index eee683ee1..e345fe51d 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeHiDream", + search_aliases=["hidream prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 1edc06f3d..58e511ef5 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -259,6 +259,7 @@ class SetClipHooks: return (clip,) class ConditioningTimestepsRange: + SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"] NodeId = 'ConditioningTimestepsRange' NodeName = 'Timesteps Range' @classmethod @@ -468,6 +469,7 @@ class SetHookKeyframes: return (hooks,) class CreateHookKeyframe: + SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"] NodeId = 'CreateHookKeyframe' NodeName = 'Create Hook Keyframe' @classmethod @@ -497,6 +499,7 @@ class CreateHookKeyframe: return (prev_hook_kf,) class CreateHookKeyframesInterpolated: + SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"] NodeId = 'CreateHookKeyframesInterpolated' NodeName = 'Create Hook Keyframes Interp.' 
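Aside on the alias additions in this area: the PR registers search aliases in two forms, matching the two node APIs in the tree. Schema-based (V3) nodes pass search_aliases=[...] to io.Schema, while the older dict-style nodes in nodes_hooks.py declare a SEARCH_ALIASES class attribute. A minimal sketch of both forms, with placeholder node names:

from comfy_api.latest import io

class MyV3Node(io.ComfyNode):
    # Placeholder node: schema-based (V3) nodes carry aliases in the schema.
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="MyV3Node",
            search_aliases=["friendly name", "alternate term"],
            category="utils",
            inputs=[],
            outputs=[],
        )

class MyLegacyNode:
    # Placeholder node: older dict-style nodes declare a class attribute instead.
    SEARCH_ALIASES = ["friendly name"]
    NodeId = "MyLegacyNode"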
@classmethod @@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated: return (prev_hook_kf,) class CreateHookKeyframesFromFloats: + SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"] NodeId = 'CreateHookKeyframesFromFloats' NodeName = 'Create Hook Keyframes From Floats' @classmethod @@ -618,6 +622,7 @@ class SetModelHooksOnCond: # Combine Hooks #------------------------------------------ class CombineHooks: + SEARCH_ALIASES = ["merge hooks"] NodeId = 'CombineHooks2' NodeName = 'Combine Hooks [2]' @classmethod diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1..774da75a3 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -54,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): @classmethod def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove @@ -71,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: # Using scale factor of 16 instead of 8 latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16}) class HunyuanVideo15ImageToVideo(io.ComfyNode): @@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode): "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode): "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) diff --git a/comfy_extras/nodes_hunyuan3d.py 
b/comfy_extras/nodes_hunyuan3d.py index adca14f62..5bb5df48e 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -618,6 +618,7 @@ class SaveGLB(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveGLB", + search_aliases=["export 3d model", "save mesh"], category="3d", is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py new file mode 100644 index 000000000..8e9f809e6 --- /dev/null +++ b/comfy_extras/nodes_image_compare.py @@ -0,0 +1,53 @@ +import nodes + +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension + + +class ImageCompare(IO.ComfyNode): + """Compares two images with a slider interface.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCompare", + display_name="Image Compare", + description="Compares two images side by side with a slider.", + category="image", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.Image.Input("image_a", optional=True), + IO.Image.Input("image_b", optional=True), + IO.ImageCompare.Input("compare_view"), + ], + outputs=[], + ) + + @classmethod + def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput: + result = {"a_images": [], "b_images": []} + + preview_node = nodes.PreviewImage() + + if image_a is not None and len(image_a) > 0: + saved = preview_node.save_images(image_a, "comfy.compare.a") + result["a_images"] = saved["ui"]["images"] + + if image_b is not None and len(image_b) > 0: + saved = preview_node.save_images(image_b, "comfy.compare.b") + result["b_images"] = saved["ui"]["images"] + + return IO.NodeOutput(ui=result) + + +class ImageCompareExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCompare, + ] + + +async def comfy_entrypoint() -> ImageCompareExtension: + return ImageCompareExtension() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index ce21caade..cb4fb24a1 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -22,6 +22,7 @@ class ImageCrop(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCrop", + search_aliases=["trim"], display_name="Image Crop", category="image/transform", inputs=[ @@ -51,6 +52,7 @@ class RepeatImageBatch(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="RepeatImageBatch", + search_aliases=["duplicate image", "clone image"], category="image/batch", inputs=[ IO.Image.Input("image"), @@ -72,6 +74,7 @@ class ImageFromBatch(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageFromBatch", + search_aliases=["select image", "pick from batch", "extract image"], category="image/batch", inputs=[ IO.Image.Input("image"), @@ -97,6 +100,7 @@ class ImageAddNoise(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageAddNoise", + search_aliases=["film grain"], category="image", inputs=[ IO.Image.Input("image"), @@ -194,11 +198,11 @@ class SaveAnimatedPNG(IO.ComfyNode): class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageStitch", + search_aliases=["combine images", "join images", "concatenate images", "side by side"], display_name="Image Stitch", description="Stitches image2 to image1 in the specified direction.\n" "If image2 is not provided, returns image1 unchanged.\n" @@ -369,11 +373,11 @@ class 
ImageStitch(IO.ComfyNode): class ResizeAndPadImage(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ResizeAndPadImage", + search_aliases=["fit to size"], category="image/transform", inputs=[ IO.Image.Input("image"), @@ -420,11 +424,11 @@ class ResizeAndPadImage(IO.ComfyNode): class SaveSVGNode(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="SaveSVGNode", + search_aliases=["export vector", "save vector graphics"], description="Save SVG files on disk.", category="image/save", inputs=[ @@ -492,11 +496,11 @@ class SaveSVGNode(IO.ComfyNode): class GetImageSize(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="GetImageSize", + search_aliases=["dimensions", "resolution", "image info"], display_name="Get Image Size", description="Returns width and height of the image, and passes it through unchanged.", category="image", @@ -527,11 +531,11 @@ class GetImageSize(IO.ComfyNode): class ImageRotate(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageRotate", + search_aliases=["turn", "flip orientation"], category="image/transform", inputs=[ IO.Image.Input("image"), @@ -557,11 +561,11 @@ class ImageRotate(IO.ComfyNode): class ImageFlip(IO.ComfyNode): - @classmethod def define_schema(cls): return IO.Schema( node_id="ImageFlip", + search_aliases=["mirror", "reflect"], category="image/transform", inputs=[ IO.Image.Input("image"), diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 9cb234be1..346c50cde 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -104,6 +104,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeKandinsky5", + search_aliases=["kandinsky prompt"], category="advanced/conditioning/kandinsky5", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 9ba1c4ba8..6aecf1561 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -21,6 +21,7 @@ class LatentAdd(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentAdd", + search_aliases=["combine latents", "sum latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -47,6 +48,7 @@ class LatentSubtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentSubtract", + search_aliases=["difference latent", "remove features"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -73,6 +75,7 @@ class LatentMultiply(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentMultiply", + search_aliases=["scale latent", "amplify latent", "latent gain"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -96,6 +99,7 @@ class LatentInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentInterpolate", + search_aliases=["blend latent", "mix latent", "lerp latent", "transition"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -134,6 +138,7 @@ class LatentConcat(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentConcat", + search_aliases=["join latents", "stitch latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -173,6 +178,7 @@ class LatentCut(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCut", + search_aliases=["crop latent", "slice latent", "extract region"], category="latent/advanced", 
inputs=[ io.Latent.Input("samples"), @@ -213,6 +219,7 @@ class LatentCutToBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCutToBatch", + search_aliases=["slice to batch", "split latent", "tile latent"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -254,6 +261,7 @@ class LatentBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentBatch", + search_aliases=["combine latents", "merge latents", "join latents"], category="latent/batch", is_deprecated=True, inputs=[ @@ -310,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentApplyOperation", + search_aliases=["transform latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ @@ -365,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentOperationTonemapReinhard", + search_aliases=["hdr latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 545588ef8..4b8d950ae 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -24,7 +24,7 @@ class Load3D(IO.ComfyNode): files = [ normalize_path(str(file_path.relative_to(base_path))) for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'} ] return IO.Schema( node_id="Load3D", @@ -75,6 +75,7 @@ class Preview3D(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Preview3D", + search_aliases=["view mesh", "3d viewer"], display_name="Preview 3D & Animation", category="3d", is_experimental=True, diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index eb888316a..c066064ac 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -104,19 +104,23 @@ class CustomComboNode(io.ComfyNode): category="utils", is_experimental=True, inputs=[io.Combo.Input("choice", options=[])], - outputs=[io.String.Output()] + outputs=[ + io.String.Output(display_name="STRING"), + io.Int.Output(display_name="INDEX"), + ], + accept_all_inputs=True, ) @classmethod - def validate_inputs(cls, choice: io.Combo.Type) -> bool: + def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool: # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. 
return True @classmethod - def execute(cls, choice: io.Combo.Type) -> io.NodeOutput: - return io.NodeOutput(choice) + def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput: + return io.NodeOutput(choice, index) class DCTestNode(io.ComfyNode): @@ -224,6 +228,7 @@ class ConvertStringToComboNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ConvertStringToComboNode", + search_aliases=["string to dropdown", "text to combo"], display_name="Convert String to Combo", category="logic", inputs=[io.String.Input("string")], @@ -239,6 +244,7 @@ class InvertBooleanNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="InvertBooleanNode", + search_aliases=["not", "toggle", "negate", "flip boolean"], display_name="Invert Boolean", category="logic", inputs=[io.Boolean.Input("boolean")], diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py new file mode 100644 index 000000000..937a0fbfb --- /dev/null +++ b/comfy_extras/nodes_lora_debug.py @@ -0,0 +1,79 @@ +import folder_paths +import comfy.utils +import comfy.sd + + +class LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." 
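The docstring of LoraLoaderBypass above summarizes bypass mode: instead of merging the low-rank update into the base weights (roughly W' = W + scale * B @ A), the LoRA path is evaluated alongside the frozen layer at forward time, so offloaded or quantized base weights never need to be rewritten. The following is a rough sketch of that idea in plain PyTorch; it is illustrative only and is not the code behind comfy.sd.load_bypass_lora_for_models.

import torch

class BypassLoRALinear(torch.nn.Module):
    """Illustrative only: wraps a frozen linear layer and adds a low-rank path at forward time."""
    def __init__(self, base: torch.nn.Linear, rank: int = 16, scale: float = 1.0):
        super().__init__()
        self.base = base          # base weights stay untouched
        self.scale = scale
        self.down = torch.nn.Linear(base.in_features, rank, bias=False)
        self.up = torch.nn.Linear(rank, base.out_features, bias=False)
        torch.nn.init.zeros_(self.up.weight)   # start as a no-op, like a freshly added LoRA

    def forward(self, x):
        # output = base_forward(x) + lora_path(x), as described in the node docstring
        return self.base(x) + self.scale * self.up(self.down(x))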
+ EXPERIMENTAL = True + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + + +NODE_CLASS_MAPPINGS = { + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)", +} diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index a2375cba7..fb89e03f4 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -78,6 +78,7 @@ class LoraSave(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LoraSave", + search_aliases=["export lora"], display_name="Extract and Save Lora", category="_for_testing", inputs=[ diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 50da5f4eb..2aec62f61 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode): generate = execute # TODO: remove +class LTXVImgToVideoInplace(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideoInplace", + category="conditioning/video_models", + inputs=[ + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Latent.Input("latent"), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.") + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput: + if bypass: + return (latent,) + + samples = latent["samples"] + _, height_scale_factor, width_scale_factor = ( + vae.downscale_index_formula + ) + + batch, _, latent_frames, latent_height, latent_width = samples.shape + width = latent_width * width_scale_factor + height = latent_height * height_scale_factor + + if image.shape[1] != height or image.shape[2] != width: + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + else: + pixels = image + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + samples[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch, 1, latent_frames, 1, 1), + dtype=torch.float32, + device=samples.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return 
io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -106,12 +159,12 @@ def get_keyframe_idxs(cond): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 - num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start + num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes class LTXVAddGuide(io.ComfyNode): - NUM_PREFIX_FRAMES = 2 - PATCHIFIER = SymmetricPatchifier(1) + PATCHIFIER = SymmetricPatchifier(1, start_end=True) @classmethod def define_schema(cls): @@ -170,11 +223,24 @@ class LTXVAddGuide(io.ComfyNode): return frame_idx, latent_idx @classmethod - def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors): + def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1): keyframe_idxs, _ = get_keyframe_idxs(cond) _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent) pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0 pixel_coords[:, 0] += frame_idx + + # The following adjusts keyframe end positions for small grid IC-LoRA. + # After dilation, the small grid has the same size and position as the large grid, + # but each token encodes a larger image patch. We adjust the end position (not start) + # so that RoPE represents the correct middle point of each token. + # keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end]) + # We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3. 
+ spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor( + scale_factors[1:], + device=pixel_coords.device, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype) + if keyframe_idxs is None: keyframe_idxs = pixel_coords else: @@ -182,26 +248,35 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly. + if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,33 +313,17 @@ class LTXVAddGuide(io.ComfyNode): frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." 
- num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, latent_image, noise_mask, - t[:, :, :num_prefix_frames], + t, strength, scale_factors, ) - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( - latent_image, - noise_mask, - t, - latent_idx, - strength, - ) - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -507,18 +566,90 @@ class LTXVPreprocess(io.ComfyNode): preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000..1966fd1bf --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,224 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod 
+ def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + 
sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXAVTextEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXAVTextEncoderLoader", + display_name="LTXV Audio Text Encoder Loader", + category="advanced/loaders", + description="[Recipes]\n\nltxav: gemma 3 12B", + inputs=[ + io.Combo.Input( + "text_encoder", + options=folder_paths.get_filename_list("text_encoders"), + ), + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + ), + io.Combo.Input( + "device", + options=["default", "cpu"], + ) + ], + outputs=[io.Clip.Output()], + ) + + @classmethod + def execute(cls, text_encoder, ckpt_name, device="default"): + clip_type = comfy.sd.CLIPType.LTXV + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) + clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + model_options = {} + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) + return io.NodeOutput(clip) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + LTXAVTextEncoderLoader, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000..f99ba13fb --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video latent by a factor of 2. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "upscale_model": ("LATENT_UPSCALE_MODEL",), + "vae": ("VAE",), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upsample_latent" + CATEGORY = "latent/video" + EXPERIMENTAL = True + + def upsample_latent( + self, + samples: dict, + upscale_model, + vae, + ) -> tuple: + """ + Upsample the input latent using the provided model. + + Args: + samples (dict): Input latent samples + upscale_model (LatentUpsampler): Loaded upscale model + vae: VAE model for normalization + + Returns: + tuple: Tuple containing the upsampled latent + """ + device = model_management.get_torch_device() + memory_required = model_management.module_size(upscale_model) + + model_dtype = next(upscale_model.parameters()).dtype + latents = samples["samples"] + input_dtype = latents.dtype + + memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate + model_management.free_memory(memory_required, device) + + try: + upscale_model.to(device) # TODO: use the comfy model management system.
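# Note (descriptive, not part of this patch): the upsampler runs on un-normalized latents, so the
# per-channel statistics are undone right before the model call and re-applied to its output below.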
+ + latents = latents.to(dtype=model_dtype, device=device) + + """Upsample latents without tiling.""" + latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) + upsampled_latents = upscale_model(latents) + finally: + upscale_model.cpu() + + upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( + upsampled_latents + ) + upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) + return_dict = samples.copy() + return_dict["samples"] = upsampled_latents + return_dict.pop("noise_mask", None) + return (return_dict,) + + +NODE_CLASS_MAPPINGS = { + "LTXVLatentUpsampler": LTXVLatentUpsampler, +} diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index 89ff2397a..2550475ae 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeLumina2", + search_aliases=["lumina prompt"], display_name="CLIP Text Encode for Lumina2", category="conditioning", description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 290e6f55e..98e8fef8f 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -50,6 +50,7 @@ class LatentCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="LatentCompositeMasked", + search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], category="latent", inputs=[ IO.Latent.Input("destination"), @@ -78,6 +79,7 @@ class ImageCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompositeMasked", + search_aliases=["paste image", "overlay", "layer"], category="image", inputs=[ IO.Image.Input("destination"), @@ -105,6 +107,7 @@ class MaskToImage(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskToImage", + search_aliases=["convert mask"], display_name="Convert Mask to Image", category="mask", inputs=[ @@ -126,6 +129,7 @@ class ImageToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageToMask", + search_aliases=["extract channel", "channel to mask"], display_name="Convert Image to Mask", category="mask", inputs=[ @@ -149,6 +153,7 @@ class ImageColorToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageColorToMask", + search_aliases=["color keying", "chroma key"], category="mask", inputs=[ IO.Image.Input("image"), @@ -194,6 +199,7 @@ class InvertMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="InvertMask", + search_aliases=["reverse mask", "flip mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -214,6 +220,7 @@ class CropMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="CropMask", + search_aliases=["cut mask", "extract mask region", "mask slice"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -239,6 +246,7 @@ class MaskComposite(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskComposite", + search_aliases=["combine masks", "blend masks", "layer masks"], category="mask", inputs=[ IO.Mask.Input("destination"), @@ -287,6 +295,7 @@ class FeatherMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FeatherMask", + search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -333,6 +342,7 @@ class 
GrowMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="GrowMask", + search_aliases=["expand mask", "shrink mask"], display_name="Grow Mask", category="mask", inputs=[ @@ -370,6 +380,7 @@ class ThresholdMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ThresholdMask", + search_aliases=["binary mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -394,6 +405,7 @@ class MaskPreview(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskPreview", + search_aliases=["show mask", "view mask", "inspect mask", "debug mask"], display_name="Preview Mask", category="mask", description="Saves the input images to your ComfyUI output directory.", diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ae5d2c563..f22b333fc 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -299,6 +299,7 @@ class RescaleCFG: return (m, ) class ModelComputeDtype: + SEARCH_ALIASES = ["model precision", "change dtype"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index f20beab7d..5384ed531 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -91,6 +91,7 @@ class CLIPMergeSimple: class CLIPSubtract: + SEARCH_ALIASES = ["clip difference", "text encoder subtract"] @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), @@ -113,6 +114,7 @@ class CLIPSubtract: class CLIPAdd: + SEARCH_ALIASES = ["combine clip"] @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), @@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: + SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -337,6 +340,7 @@ class VAESave: return {} class ModelSave: + SEARCH_ALIASES = ["export model", "checkpoint save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 1355b3c93..176e6bc2f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -244,6 +245,10 @@ class ModelPatchLoader: elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) config = {} + if 'control_layers.4.adaLN_modulation.0.weight' not in sd: + config['n_control_layers'] = 3 + config['additional_in_dim'] = 17 + config['refiner_control'] = True if 'control_layers.14.adaLN_modulation.0.weight' in sd: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 @@ -253,10 +258,18 @@ class ModelPatchLoader: if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, 
context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: @@ -520,6 +533,38 @@ class USOStyleReference: return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 67377e1bc..4ab2fb7e8 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -12,6 +12,7 @@ class Morphology(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Morphology", + search_aliases=["erode", "dilate"], display_name="ImageMorphology", category="image/postprocessing", inputs=[ @@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageRGBToYUV", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("image"), @@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageYUVToRGB", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("Y"), diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index a23e87b1f..2f1b73e60 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodePixArtAlpha", + search_aliases=["pixart prompt"], category="advanced/conditioning", description="Encodes text and sets the resolution conditioning for PixArt Alpha. 
Does not apply to PixArt Sigma.", inputs=[ diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 01afa13a1..a52a90e2c 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -254,6 +254,7 @@ class ResizeType(str, Enum): SCALE_HEIGHT = "scale height" SCALE_TOTAL_PIXELS = "scale total pixels" MATCH_SIZE = "match size" + SCALE_TO_MULTIPLE = "scale to multiple" def is_image(input: torch.Tensor) -> bool: # images have 4 dimensions: [batch, height, width, channels] @@ -328,7 +329,7 @@ def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method if height < width: width = round((width / height) * shorter_size) height = shorter_size - elif width > height: + elif width < height: height = round((height / width) * shorter_size) width = shorter_size else: @@ -363,8 +364,44 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str input = finalize_image_mask_input(input, is_type_image) return input -class ResizeImageMaskNode(io.ComfyNode): +def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor: + if multiple <= 1: + return input + is_type_image = is_image(input) + if is_type_image: + _, height, width, _ = input.shape + else: + _, height, width = input.shape + target_w = (width // multiple) * multiple + target_h = (height // multiple) * multiple + if target_w == 0 or target_h == 0: + return input + if target_w == width and target_h == height: + return input + s_w = target_w / width + s_h = target_h / height + if s_w >= s_h: + scaled_w = target_w + scaled_h = int(math.ceil(height * s_w)) + if scaled_h < target_h: + scaled_h = target_h + else: + scaled_h = target_h + scaled_w = int(math.ceil(width * s_h)) + if scaled_w < target_w: + scaled_w = target_w + input = init_image_mask_input(input, is_type_image) + input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + x0 = (scaled_w - target_w) // 2 + y0 = (scaled_h - target_h) // 2 + x1 = x0 + target_w + y1 = y0 + target_h + if is_type_image: + return input[:, y0:y1, x0:x1, :] + return input[:, y0:y1, x0:x1] +class ResizeImageMaskNode(io.ComfyNode): scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] crop_methods = ["disabled", "center"] @@ -378,47 +415,67 @@ class ResizeImageMaskNode(io.ComfyNode): longer_size: int shorter_size: int megapixels: float + multiple: int @classmethod def define_schema(cls): template = io.MatchType.Template("input_type", [io.Image, io.Mask]) - crop_combo = io.Combo.Input("crop", options=cls.crop_methods, default="center") + crop_combo = io.Combo.Input( + "crop", + options=cls.crop_methods, + default="center", + tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.", + ) return io.Schema( node_id="ResizeImageMaskNode", display_name="Resize Image/Mask", + description="Resize an image or mask using various scaling methods.", category="transform", + search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"], inputs=[ io.MatchType.Input("input", template=template), - io.DynamicCombo.Input("resize_type", options=[ - io.DynamicCombo.Option(ResizeType.SCALE_BY, [ - io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01), + io.DynamicCombo.Input( + "resize_type", + 
tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.", + options=[ + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."), + crop_combo, ]), - io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), - crop_combo, + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ - io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ - io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ - io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ - io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Width auto-adjusts to preserve aspect ratio."), ]), - io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ - io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."), ]), - io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ - io.MultiType.Input("match", [io.Image, io.Mask]), - crop_combo, + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."), + crop_combo, ]), - ]), - io.Combo.Input("scale_method", options=cls.scale_methods, default="area"), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."), + ]), + ], + ), + io.Combo.Input( + "scale_method", + options=cls.scale_methods, + default="area", + tooltip="Interpolation algorithm. 
'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.", + ), ], outputs=[io.MatchType.Output(template=template, display_name="resized")] ) @@ -442,6 +499,8 @@ class ResizeImageMaskNode(io.ComfyNode): return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) elif selected_type == ResizeType.MATCH_SIZE: return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_TO_MULTIPLE: + return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method)) raise ValueError(f"Unsupported resize type: {selected_type}") def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: @@ -506,6 +565,7 @@ class BatchImagesNode(io.ComfyNode): node_id="BatchImagesNode", display_name="Batch Images", category="image", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], inputs=[ io.Autogrow.Input("images", template=autogrow_template) ], @@ -524,6 +584,7 @@ class BatchMasksNode(io.ComfyNode): autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) return io.Schema( node_id="BatchMasksNode", + search_aliases=["combine masks", "stack masks", "merge masks"], display_name="Batch Masks", category="mask", inputs=[ @@ -544,6 +605,7 @@ class BatchLatentsNode(io.ComfyNode): autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) return io.Schema( node_id="BatchLatentsNode", + search_aliases=["combine latents", "stack latents", "merge latents"], display_name="Batch Latents", category="latent", inputs=[ @@ -567,6 +629,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode): prefix="input", min=1, max=50) return io.Schema( node_id="BatchImagesMasksLatentsNode", + search_aliases=["combine batch", "merge batch", "stack inputs"], display_name="Batch Images/Masks/Latents", category="util", inputs=[ diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93..b0a6f279d 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 14782cb2b..736213a47 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode): @classmethod def execute(cls, width, height, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove @@ -65,6 +65,7 @@ class CLIPTextEncodeSD3(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeSD3", + search_aliases=["sd3 prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f62..8d3e65cc5 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text 
concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -31,6 +32,7 @@ class StringSubstring(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringSubstring", + search_aliases=["extract text", "text portion"], display_name="Substring", category="utils/string", inputs=[ @@ -53,6 +55,7 @@ class StringLength(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringLength", + search_aliases=["character count", "text size"], display_name="Length", category="utils/string", inputs=[ @@ -73,6 +76,7 @@ class CaseConverter(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CaseConverter", + search_aliases=["text case", "uppercase", "lowercase", "capitalize"], display_name="Case Converter", category="utils/string", inputs=[ @@ -105,6 +109,7 @@ class StringTrim(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringTrim", + search_aliases=["clean whitespace", "remove whitespace"], display_name="Trim", category="utils/string", inputs=[ @@ -135,6 +140,7 @@ class StringReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringReplace", + search_aliases=["find and replace", "substitute", "swap text"], display_name="Replace", category="utils/string", inputs=[ @@ -157,6 +163,7 @@ class StringContains(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringContains", + search_aliases=["text includes", "string includes"], display_name="Contains", category="utils/string", inputs=[ @@ -184,6 +191,7 @@ class StringCompare(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringCompare", + search_aliases=["text match", "string equals", "starts with", "ends with"], display_name="Compare", category="utils/string", inputs=[ @@ -219,6 +227,7 @@ class RegexMatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexMatch", + search_aliases=["pattern match", "text contains", "string match"], display_name="Regex Match", category="utils/string", inputs=[ @@ -259,6 +268,7 @@ class RegexExtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexExtract", + search_aliases=["pattern extract", "text parser", "parse text"], display_name="Regex Extract", category="utils/string", inputs=[ @@ -333,6 +343,7 @@ class RegexReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexReplace", + search_aliases=["pattern replace", "find and replace", "substitution"], display_name="Regex Replace", category="utils/string", description="Find and replace text using regex patterns.", diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 364804205..024a89391 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -18,6 +18,7 @@ import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers from comfy.weight_adapter import adapters, adapter_maps +from comfy.weight_adapter.bypass import BypassInjectionManager from comfy_api.latest import ComfyExtension, io, ui from comfy.utils import ProgressBar @@ -339,6 +340,11 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + for param_groups in self.optimizer.param_groups: + for param in param_groups["params"]: + if param.grad is None: + continue + param.grad.data = 
param.grad.data.to(param.data.dtype) self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) @@ -498,9 +504,9 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): num_images = sum(t.shape[0] for t in latents) multi_res = False # Not using multi_res path in bucket mode - logging.info(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") + logging.debug(f"Bucket mode: {num_buckets} buckets, {num_images} total samples") for i, lat in enumerate(latents): - logging.info(f" Bucket {i}: shape {lat.shape}") + logging.debug(f" Bucket {i}: shape {lat.shape}") return latents, num_images, multi_res # Non-bucket mode @@ -509,7 +515,7 @@ def _prepare_latents_and_count(latents, dtype, bucket_mode): latents = [t.to(dtype) for t in latents] for latent in latents: all_shapes.add(latent.shape) - logging.info(f"Latent shapes: {all_shapes}") + logging.debug(f"Latent shapes: {all_shapes}") if len(all_shapes) > 1: multi_res = True else: @@ -545,7 +551,7 @@ def _validate_and_expand_conditioning(positive, num_images, bucket_mode): if bucket_mode: return positive # Skip validation in bucket mode - logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") + logging.debug(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: return positive * num_images elif len(positive) != num_images: @@ -596,6 +602,8 @@ def _create_weight_adapter( shape = module.weight.shape lora_params = {} + logging.debug(f"Creating weight adapter for {key} with shape {shape}") + if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) dora_scale = existing_weights.get(f"{key}.dora_scale", None) @@ -690,6 +698,61 @@ def _setup_lora_adapters(mp, existing_weights, algorithm, lora_dtype, rank): return lora_sd, all_weight_adapters +def _setup_lora_adapters_bypass(mp, existing_weights, algorithm, lora_dtype, rank): + """Setup LoRA adapters in bypass mode. + + In bypass mode: + - Weight adapters (lora/lokr/oft) use bypass injection (forward hook) + - Bias/norm adapters (BiasDiff) still use weight wrapper (direct modification) + + This is useful when the base model weights are quantized and cannot be + directly modified. 
+ + Args: + mp: Model patcher + existing_weights: Dict of existing LoRA weights + algorithm: Algorithm name for new adapters + lora_dtype: dtype for LoRA weights + rank: Rank for new LoRA adapters + + Returns: + tuple: (lora_sd dict, all_weight_adapters list, bypass_manager) + """ + lora_sd = {} + all_weight_adapters = [] + bypass_manager = BypassInjectionManager() + + for n, m in mp.model.named_modules(): + if hasattr(m, "weight_function"): + if m.weight is not None: + adapter, params = _create_weight_adapter( + m, n, existing_weights, algorithm, lora_dtype, rank + ) + lora_sd.update(params) + all_weight_adapters.append(adapter) + + key = f"{n}.weight" + # BiasDiff (for 1D weights like norm) uses weight wrapper, not bypass + # Only use bypass for adapters that have h() method (lora/lokr/oft) + if isinstance(adapter, BiasDiff): + mp.add_weight_wrapper(key, adapter) + logging.debug(f"[BypassMode] Added 1D weight adapter (weight wrapper) for {key}") + else: + bypass_manager.add_adapter(key, adapter, strength=1.0) + logging.debug(f"[BypassMode] Added weight adapter (bypass) for {key}") + + if hasattr(m, "bias") and m.bias is not None: + # Bias adapters still use weight wrapper (bias is usually not quantized) + bias_adapter, bias_params = _create_bias_adapter(m, n, lora_dtype) + lora_sd.update(bias_params) + key = f"{n}.bias" + mp.add_weight_wrapper(key, bias_adapter) + all_weight_adapters.append(bias_adapter) + logging.debug(f"[BypassMode] Added bias adapter (weight wrapper) for {key}") + + return lora_sd, all_weight_adapters, bypass_manager + + def _create_optimizer(optimizer_name, parameters, learning_rate): """Create optimizer based on name. @@ -884,11 +947,13 @@ class TrainLoraNode(io.ComfyNode): default=False, tooltip="Enable resolution bucket mode. When enabled, expects pre-bucketed latents from ResolutionBucket node.", ), + io.Boolean.Input( + "bypass_mode", + default=False, + tooltip="Enable bypass mode for training. When enabled, adapters are applied via forward hooks instead of weight modification. 
Useful for quantized models where weights cannot be directly modified.", + ), + ], + outputs=[ - io.Model.Output( - display_name="model", tooltip="Model with LoRA applied" - ), io.Custom("LORA_MODEL").Output( display_name="lora", tooltip="LoRA weights" ), @@ -919,6 +984,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing, existing_lora, bucket_mode, + bypass_mode, ): # Extract scalars from lists (due to is_input_list=True) model = model[0] @@ -936,6 +1002,7 @@ class TrainLoraNode(io.ComfyNode): gradient_checkpointing = gradient_checkpointing[0] existing_lora = existing_lora[0] bucket_mode = bucket_mode[0] + bypass_mode = bypass_mode[0] # Process latents based on mode if bucket_mode: @@ -968,9 +1035,16 @@ class TrainLoraNode(io.ComfyNode): existing_weights, existing_steps = _load_existing_lora(existing_lora) # Setup LoRA adapters - lora_sd, all_weight_adapters = _setup_lora_adapters( - mp, existing_weights, algorithm, lora_dtype, rank - ) + bypass_manager = None + if bypass_mode: + logging.debug("Using bypass mode for training") + lora_sd, all_weight_adapters, bypass_manager = _setup_lora_adapters_bypass( + mp, existing_weights, algorithm, lora_dtype, rank + ) + else: + lora_sd, all_weight_adapters = _setup_lora_adapters( + mp, existing_weights, algorithm, lora_dtype, rank + ) # Create optimizer and loss function optimizer = _create_optimizer( @@ -1029,6 +1103,14 @@ class TrainLoraNode(io.ComfyNode): guider = TrainGuider(mp) guider.set_conds(positive) + # Inject bypass hooks if bypass mode is enabled + bypass_injections = None + if bypass_manager is not None: + bypass_injections = bypass_manager.create_injections(mp.model) + for injection in bypass_injections: + injection.inject(mp) + logging.debug(f"[BypassMode] Injected {bypass_manager.get_hook_count()} bypass hooks") + # Run training loop try: _run_training_loop( @@ -1041,6 +1123,11 @@ class TrainLoraNode(io.ComfyNode): multi_res, ) finally: + # Eject bypass hooks if they were injected + if bypass_injections is not None: + for injection in bypass_injections: + injection.eject(mp) + logging.debug("[BypassMode] Ejected bypass hooks") for m in mp.model.modules(): unpatch(m) del train_sampler, optimizer @@ -1052,7 +1139,9 @@ class TrainLoraNode(io.ComfyNode): for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) + # mp in train node is highly specialized for training + # using it for inference will result in bad behavior, so we don't return it + return io.NodeOutput(lora_sd, loss_map, steps + existing_steps) class LoraModelLoader(io.ComfyNode):# @@ -1101,6 +1190,7 @@ class SaveLoRA(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveLoRA", + search_aliases=["export lora"], display_name="Save LoRA Weights", category="loaders", is_experimental=True, @@ -1144,6 +1234,7 @@ class LossGraphNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LossGraphNode", + search_aliases=["training chart", "training visualization", "plot loss"], display_name="Plot Loss Graph", category="training", is_experimental=True, diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index 4d62b87be..97b9e948d 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -53,6 +53,7 @@ class ImageUpscaleWithModel(io.ComfyNode): node_id="ImageUpscaleWithModel", display_name="Upscale Image (using Model)", category="image/upscaling", + search_aliases=["upscale", "upscaler", "upsc", "enlarge
image", "super resolution", "hires", "superres", "increase resolution"], inputs=[ io.UpscaleModel.Input("upscale_model"), io.Image.Input("image"), @@ -78,18 +79,20 @@ class ImageUpscaleWithModel(io.ComfyNode): overlap = 32 oom = True - while oom: - try: - steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) - pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) - oom = False - except model_management.OOM_EXCEPTION as e: - tile //= 2 - if tile < 128: - raise e + try: + while oom: + try: + steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) + pbar = comfy.utils.ProgressBar(steps) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + oom = False + except model_management.OOM_EXCEPTION as e: + tile //= 2 + if tile < 128: + raise e + finally: + upscale_model.to("cpu") - upscale_model.to("cpu") s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) return io.NodeOutput(s) diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index c609e03da..ccf7b63d3 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -16,6 +16,7 @@ class SaveWEBM(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveWEBM", + search_aliases=["export webm"], category="image/video", is_experimental=True, inputs=[ @@ -69,6 +70,7 @@ class SaveVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveVideo", + search_aliases=["export video"], display_name="Save Video", category="image/video", description="Saves the input images to your ComfyUI output directory.", @@ -116,6 +118,7 @@ class CreateVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CreateVideo", + search_aliases=["images to video"], display_name="Create Video", category="image/video", description="Create a video from images.", @@ -140,6 +143,7 @@ class GetVideoComponents(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="GetVideoComponents", + search_aliases=["extract frames", "split video", "video to images", "demux"], display_name="Get Video Components", category="image/video", description="Extracts all components from a video: frames, audio, and framerate.", @@ -167,6 +171,7 @@ class LoadVideo(io.ComfyNode): files = folder_paths.filter_files_content_types(files, ["video"]) return io.Schema( node_id="LoadVideo", + search_aliases=["import video", "open video", "video file"], display_name="Load Video", category="image/video", inputs=[ diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index d32aad98e..2ff012134 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -8,9 +8,10 @@ import comfy.latent_formats import comfy.clip_vision import json import numpy as np -from typing import Tuple +from typing import Tuple, TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import logging class WanImageToVideo(io.ComfyNode): @classmethod @@ -286,6 +287,7 @@ class WanVaceToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanVaceToVideo", + search_aliases=["video conditioning", "video control"], category="conditioning/video_models", inputs=[ 
io.Conditioning.Input("positive"), @@ -704,6 +706,7 @@ class WanTrackToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanTrackToVideo", + search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"], category="conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), @@ -1288,6 +1291,171 @@ class Wan22ImageToVideoLatent(io.ComfyNode): return io.NodeOutput(out_latent) +from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleWrapper, MultiTalkCrossAttnPatch, MultiTalkGetAttnMapPatch, project_audio_features +class WanInfiniteTalkToVideo(io.ComfyNode): + class DCValues(TypedDict): + mode: str + audio_encoder_output_2: io.AudioEncoderOutput.Type + mask: io.Mask.Type + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="WanInfiniteTalkToVideo", + category="conditioning/video_models", + inputs=[ + io.DynamicCombo.Input("mode", options=[ + io.DynamicCombo.Option("single_speaker", []), + io.DynamicCombo.Option("two_speakers", [ + io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True), + io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."), + io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."), + ]), + ]), + io.Model.Input("model"), + io.ModelPatch.Input("model_patch"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Image.Input("start_image", optional=True), + io.AudioEncoderOutput.Input("audio_encoder_output_1"), + io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."), + io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01), + io.Image.Input("previous_frames", optional=True), + ], + outputs=[ + io.Model.Output(display_name="model"), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + io.Int.Output(display_name="trim_image"), + ], + ) + + @classmethod + def execute(cls, mode: DCValues, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, + start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput: + + if previous_frames is not None and previous_frames.shape[0] < motion_frame_count: + raise ValueError("Not enough previous frames provided.") + + if mode["mode"] == "two_speakers": + audio_encoder_output_2 = mode["audio_encoder_output_2"] + mask_1 = mode["mask_1"] + mask_2 = mode["mask_2"] + + if audio_encoder_output_2 is not None: + if mask_1 is None or mask_2 is None: + raise ValueError("Masks must be provided if two audio encoder outputs are used.") + + ref_masks = None + if mask_1 is not None and mask_2 is not None: + if audio_encoder_output_2 is None: + raise ValueError("Second audio encoder output must be provided if two masks are used.") + ref_masks = torch.cat([mask_1, mask_2]) + + latent = torch.zeros([1, 16, ((length - 1) // 4) + 1, height // 8, width // 8], 
device=comfy.model_management.intermediate_device()) + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5 + image[:start_image.shape[0]] = start_image + + concat_latent_image = vae.encode(image[:, :, :, :3]) + concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + model_patched = model.clone() + + encoded_audio_list = [] + seq_lengths = [] + + for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]: + if audio_encoder_output is None: + continue + all_layers = audio_encoder_output["encoded_audio_all_layers"] + encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512] + encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512] + encoded_audio_list.append(encoded_audio) + seq_lengths.append(encoded_audio.shape[0]) + + # Pad / combine depending on multi_audio_type + multi_audio_type = "add" + if len(encoded_audio_list) > 1: + if multi_audio_type == "para": + max_len = max(seq_lengths) + padded = [] + for emb in encoded_audio_list: + if emb.shape[0] < max_len: + pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype) + emb = torch.cat([emb, pad], dim=0) + padded.append(emb) + encoded_audio_list = padded + elif multi_audio_type == "add": + total_len = sum(seq_lengths) + full_list = [] + offset = 0 + for emb, seq_len in zip(encoded_audio_list, seq_lengths): + full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype) + full[offset:offset+seq_len] = emb + full_list.append(full) + offset += seq_len + encoded_audio_list = full_list + + token_ref_target_masks = None + if ref_masks is not None: + token_ref_target_masks = torch.nn.functional.interpolate( + ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0] + token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1) + + # when extending from previous frames + if previous_frames is not None: + motion_frames = comfy.utils.common_upscale(previous_frames[-motion_frame_count:].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_offset = previous_frames.shape[0] - motion_frame_count + + audio_start = frame_offset + audio_end = audio_start + length + logging.info(f"InfiniteTalk: Processing audio frames {audio_start} - {audio_end}") + + motion_frames_latent = vae.encode(motion_frames[:, :, :, :3]) + trim_image = motion_frame_count + else: + audio_start = trim_image = 0 + audio_end = length + motion_frames_latent = concat_latent_image[:, :, :1] + + audio_embed = 
project_audio_features(model_patch.model.audio_proj, encoded_audio_list, audio_start, audio_end).to(model_patched.model_dtype()) + model_patched.model_options["transformer_options"]["audio_embeds"] = audio_embed + + # add outer sample wrapper + model_patched.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, + "infinite_talk_outer_sample", + InfiniteTalkOuterSampleWrapper( + motion_frames_latent, + model_patch, + is_extend=previous_frames is not None, + )) + # add cross-attention patch + model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch") + if token_ref_target_masks is not None: + model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch") + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image) + + class WanExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -1307,6 +1475,7 @@ class WanExtension(ComfyExtension): WanHuMoImageToVideo, WanAnimateToVideo, Wan22ImageToVideoLatent, + WanInfiniteTalkToVideo, ] async def comfy_entrypoint() -> WanExtension: diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py index 5f39afa46..d60baf230 100644 --- a/comfy_extras/nodes_wanmove.py +++ b/comfy_extras/nodes_wanmove.py @@ -324,6 +324,7 @@ class GenerateTracks(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="GenerateTracks", + search_aliases=["motion paths", "camera movement", "trajectory"], category="conditioning/video_models", inputs=[ io.Int.Input("width", default=832, min=16, max=4096, step=16), diff --git a/comfy_extras/nodes_webcam.py b/comfy_extras/nodes_webcam.py index 5bf80b4c6..6349ac017 100644 --- a/comfy_extras/nodes_webcam.py +++ b/comfy_extras/nodes_webcam.py @@ -5,6 +5,7 @@ MAX_RESOLUTION = nodes.MAX_RESOLUTION class WebcamCapture(nodes.LoadImage): + SEARCH_ALIASES = ["camera input", "live capture", "camera feed", "snapshot"] @classmethod def INPUT_TYPES(s): return { diff --git a/comfy_extras/nodes_zimage.py b/comfy_extras/nodes_zimage.py new file mode 100644 index 000000000..2ee3c43b1 --- /dev/null +++ b/comfy_extras/nodes_zimage.py @@ -0,0 +1,88 @@ +import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import math +import comfy.utils + + +class TextEncodeZImageOmni(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeZImageOmni", + category="advanced/conditioning", + is_experimental=True, + inputs=[ + io.Clip.Input("clip"), + io.ClipVision.Input("image_encoder", optional=True), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Boolean.Input("auto_resize_images", default=True), + io.Vae.Input("vae", optional=True), + io.Image.Input("image1", optional=True), + io.Image.Input("image2", optional=True), + io.Image.Input("image3", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, prompt, image_encoder=None, auto_resize_images=True, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput: + ref_latents = [] + images = list(filter(lambda a: a is not None, [image1, image2, image3])) + + prompt_list = [] + template = None + if len(images) > 0: + prompt_list = ["<|im_start|>user\n<|vision_start|>"] + prompt_list += ["<|vision_end|><|vision_start|>"] * (len(images) - 1) + prompt_list += ["<|vision_end|><|im_end|>"] + template = 
"<|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n<|vision_start|>" + + encoded_images = [] + + for i, image in enumerate(images): + if image_encoder is not None: + encoded_images.append(image_encoder.encode_image(image)) + + if vae is not None: + if auto_resize_images: + samples = image.movedim(-1, 1) + total = int(1024 * 1024) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by / 8.0) * 8 + height = round(samples.shape[2] * scale_by / 8.0) * 8 + + image = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + ref_latents.append(vae.encode(image)) + + tokens = clip.tokenize(prompt, llama_template=template) + conditioning = clip.encode_from_tokens_scheduled(tokens) + + extra_text_embeds = [] + for p in prompt_list: + tokens = clip.tokenize(p, llama_template="{}") + text_embeds = clip.encode_from_tokens_scheduled(tokens) + extra_text_embeds.append(text_embeds[0][0]) + + if len(ref_latents) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True) + if len(encoded_images) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"clip_vision_outputs": encoded_images}, append=True) + if len(extra_text_embeds) > 0: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents_text_embeds": extra_text_embeds}, append=True) + + return io.NodeOutput(conditioning) + + +class ZImageExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TextEncodeZImageOmni, + ] + + +async def comfy_entrypoint() -> ZImageExtension: + return ZImageExtension() diff --git a/comfyui_version.py b/comfyui_version.py index 1ed60fe5c..b1ebaa115 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.7.0" +__version__ = "0.11.1" diff --git a/cuda_malloc.py b/cuda_malloc.py index ee2bc4b69..b2182df37 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -1,8 +1,10 @@ import os import importlib.util -from comfy.cli_args import args, PerformanceFeature +from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram import subprocess +import comfy_aimdo.control + #Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import. 
def get_gpu_names(): if os.name == 'nt': @@ -85,8 +87,14 @@ if not args.cuda_malloc: except: pass +if enables_dynamic_vram() and comfy_aimdo.control.init(): + args.cuda_malloc = False + os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "" -if args.cuda_malloc and not args.disable_cuda_malloc: +if args.disable_cuda_malloc: + args.cuda_malloc = False + +if args.cuda_malloc: env_var = os.environ.get('PYTORCH_CUDA_ALLOC_CONF', None) if env_var is None: env_var = "backend:cudaMallocAsync" diff --git a/execution.py b/execution.py index 648f204ec..3dbab82e6 100644 --- a/execution.py +++ b/execution.py @@ -9,9 +9,11 @@ import traceback from enum import Enum from typing import List, Literal, NamedTuple, Optional, Union import asyncio +from contextlib import nullcontext import torch +import comfy.memory_management import comfy.model_management from latent_preview import set_preview_method import nodes @@ -175,7 +177,7 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= continue obj = cached.outputs[output_index] input_data_all[x] = obj - elif input_category is not None: + elif input_category is not None or (is_v3 and class_def.ACCEPT_ALL_INPUTS): input_data_all[x] = [input_data] if is_v3: @@ -515,7 +517,19 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + + #Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows + #will cause all sorts of incompatible memory shapes to fragment the pytorch alloc + #that we just want to cull out each model run. + allocator = comfy.memory_management.aimdo_allocator + with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())): + try: + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) + finally: + if allocator is not None: + comfy.model_management.reset_cast_buffers() + torch.cuda.synchronize() + if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -1000,22 +1014,34 @@ async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[ outputs = set() for x in prompt: if 'class_type' not in prompt[x]: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title') error = { - "type": "invalid_prompt", - "message": "Cannot execute because a node is missing the class_type property.", + "type": "missing_node_type", + "message": f"Node '{node_title or f'ID #{x}'}' has no class_type. 
The workflow may be corrupted or a custom node is missing.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": None, + "node_title": node_title + } } return (False, error, [], {}) class_type = prompt[x]['class_type'] class_ = nodes.NODE_CLASS_MAPPINGS.get(class_type, None) if class_ is None: + node_data = prompt[x] + node_title = node_data.get('_meta', {}).get('title', class_type) error = { - "type": "invalid_prompt", - "message": f"Cannot execute because node {class_type} does not exist.", + "type": "missing_node_type", + "message": f"Node '{node_title}' not found. The custom node may not be installed.", "details": f"Node ID '#{x}'", - "extra_info": {} + "extra_info": { + "node_id": x, + "class_type": class_type, + "node_title": node_title + } } return (False, error, [], {}) diff --git a/latent_preview.py b/latent_preview.py index d52e3f7a1..a9d777661 100644 --- a/latent_preview.py +++ b/latent_preview.py @@ -11,7 +11,7 @@ import logging default_preview_method = args.preview_method MAX_PREVIEW_RESOLUTION = args.preview_size -VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] def preview_to_image(latent_image, do_scale=True): if do_scale: diff --git a/main.py b/main.py index 0e07a95da..b8c951375 100644 --- a/main.py +++ b/main.py @@ -5,8 +5,9 @@ import os import importlib.util import folder_paths import time -from comfy.cli_args import args +from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +from app.assets.scanner import seed_assets import itertools import utils.extra_config import logging @@ -172,6 +173,7 @@ import gc if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") + import comfy.utils import execution @@ -183,6 +185,33 @@ import comfyui_version import app.logger import hook_breaker_ac10a0 +import comfy.memory_management +import comfy.model_patcher + +import comfy_aimdo.control +import comfy_aimdo.torch + +if enables_dynamic_vram(): + if comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): + if args.verbose == 'DEBUG': + comfy_aimdo.control.set_log_debug() + elif args.verbose == 'CRITICAL': + comfy_aimdo.control.set_log_critical() + elif args.verbose == 'ERROR': + comfy_aimdo.control.set_log_error() + elif args.verbose == 'WARNING': + comfy_aimdo.control.set_log_warning() + else: #INFO + comfy_aimdo.control.set_log_info() + + comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic + comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator() + logging.info("DynamicVRAM support detected and enabled") + else: + logging.info("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") + comfy.memory_management.aimdo_allocator = None + + def cuda_malloc_warning(): device = comfy.model_management.get_torch_device() device_name = comfy.model_management.get_torch_device_name(device) @@ -324,6 +353,8 @@ def setup_database(): from app.database.db import init_db, dependencies_available if dependencies_available(): init_db() + if not args.disable_assets_autoscan: + seed_assets(["models"], enable_logging=True) except Exception as e: logging.error(f"Failed to initialize database. 
Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") diff --git a/manager_requirements.txt b/manager_requirements.txt index 6585b0c19..c420cc48e 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.4 +comfyui_manager==4.1b1 diff --git a/nodes.py b/nodes.py index 662907ae6..1cb43d9e2 100644 --- a/nodes.py +++ b/nodes.py @@ -5,6 +5,7 @@ import torch import os import sys import json +import glob import hashlib import inspect import traceback @@ -69,6 +70,7 @@ class CLIPTextEncode(ComfyNodeABC): CATEGORY = "conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." + SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] def encode(self, clip, text): if clip is None: @@ -85,11 +87,14 @@ class ConditioningCombine: FUNCTION = "combine" CATEGORY = "conditioning" + SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): return (conditioning_1 + conditioning_2, ) class ConditioningAverage : + SEARCH_ALIASES = ["blend prompts", "interpolate conditioning", "mix prompts", "style fusion", "weighted blend"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_to": ("CONDITIONING", ), "conditioning_from": ("CONDITIONING", ), @@ -156,6 +161,8 @@ class ConditioningConcat: return (out, ) class ConditioningSetArea: + SEARCH_ALIASES = ["regional prompt", "area prompt", "spatial conditioning", "localized prompt"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -214,6 +221,8 @@ class ConditioningSetAreaStrength: class ConditioningSetMask: + SEARCH_ALIASES = ["masked prompt", "regional inpaint conditioning", "mask conditioning"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -239,6 +248,8 @@ class ConditioningSetMask: return (c, ) class ConditioningZeroOut: + SEARCH_ALIASES = ["null conditioning", "clear conditioning"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", )}} @@ -293,9 +304,14 @@ class VAEDecode: CATEGORY = "latent" DESCRIPTION = "Decodes latent images back into pixel space images." 
+ SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] def decode(self, vae, samples): - images = vae.decode(samples["samples"]) + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + images = vae.decode(latent) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -341,6 +357,7 @@ class VAEEncode: FUNCTION = "encode" CATEGORY = "latent" + SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) @@ -374,14 +391,15 @@ class VAEEncodeForInpaint: CATEGORY = "latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): - x = (pixels.shape[1] // vae.downscale_ratio) * vae.downscale_ratio - y = (pixels.shape[2] // vae.downscale_ratio) * vae.downscale_ratio + downscale_ratio = vae.spacial_compression_encode() + x = (pixels.shape[1] // downscale_ratio) * downscale_ratio + y = (pixels.shape[2] // downscale_ratio) * downscale_ratio mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(pixels.shape[1], pixels.shape[2]), mode="bilinear") pixels = pixels.clone() if pixels.shape[1] != x or pixels.shape[2] != y: - x_offset = (pixels.shape[1] % vae.downscale_ratio) // 2 - y_offset = (pixels.shape[2] % vae.downscale_ratio) // 2 + x_offset = (pixels.shape[1] % downscale_ratio) // 2 + y_offset = (pixels.shape[2] % downscale_ratio) // 2 pixels = pixels[:,x_offset:x + x_offset, y_offset:y + y_offset,:] mask = mask[:,:,x_offset:x + x_offset, y_offset:y + y_offset] @@ -457,6 +475,8 @@ class InpaintModelConditioning: class SaveLatent: + SEARCH_ALIASES = ["export latent"] + def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -508,6 +528,8 @@ class SaveLatent: class LoadLatent: + SEARCH_ALIASES = ["import latent", "open latent"] + @classmethod def INPUT_TYPES(s): input_dir = folder_paths.get_input_directory() @@ -544,6 +566,8 @@ class LoadLatent: class CheckpointLoader: + SEARCH_ALIASES = ["load model", "model loader"] + @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"), ), @@ -575,6 +599,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." + SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] def load_checkpoint(self, ckpt_name): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -582,6 +607,8 @@ class CheckpointLoaderSimple: return out[:3] class DiffusersLoader: + SEARCH_ALIASES = ["load diffusers model"] + @classmethod def INPUT_TYPES(cls): paths = [] @@ -661,6 +688,7 @@ class LoraLoader: CATEGORY = "loaders" DESCRIPTION = "LoRAs are used to modify diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." 
+ SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] def load_lora(self, model, clip, lora_name, strength_model, strength_clip): if strength_model == 0 and strength_clip == 0: @@ -695,7 +723,7 @@ class LoraLoaderModelOnly(LoraLoader): return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) class VAELoader: - video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod def vae_list(s): @@ -783,6 +811,7 @@ class VAELoader: #TODO: scale factor? def load_vae(self, vae_name): + metadata = None if vae_name == "pixel_space": sd = {} sd["pixel_space_vae"] = torch.tensor(1.0) @@ -793,8 +822,8 @@ class VAELoader: vae_path = folder_paths.get_full_path_or_raise("vae_approx", vae_name) else: vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) - sd = comfy.utils.load_torch_file(vae_path) - vae = comfy.sd.VAE(sd=sd) + sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae.throw_exception_if_invalid() return (vae,) @@ -807,6 +836,7 @@ class ControlNetLoader: FUNCTION = "load_controlnet" CATEGORY = "loaders" + SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -883,6 +913,7 @@ class ControlNetApplyAdvanced: FUNCTION = "apply_controlnet" CATEGORY = "conditioning/controlnet" + SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): if strength == 0: @@ -970,7 +1001,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1048,6 +1079,8 @@ class StyleModelLoader: class StyleModelApply: + SEARCH_ALIASES = ["style transfer"] + @classmethod def INPUT_TYPES(s): return {"required": {"conditioning": ("CONDITIONING", ), @@ -1193,13 +1226,16 @@ class EmptyLatentImage: CATEGORY = "latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." 
+ SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) - return ({"samples":latent}, ) + return ({"samples": latent, "downscale_ratio_spacial": 8}, ) class LatentFromBatch: + SEARCH_ALIASES = ["select from batch", "pick latent", "batch subset"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1232,6 +1268,8 @@ class LatentFromBatch: return (s,) class RepeatLatentBatch: + SEARCH_ALIASES = ["duplicate latent", "clone latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1258,6 +1296,8 @@ class RepeatLatentBatch: return (s,) class LatentUpscale: + SEARCH_ALIASES = ["enlarge latent", "resize latent"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] crop_methods = ["disabled", "center"] @@ -1292,6 +1332,8 @@ class LatentUpscale: return (s,) class LatentUpscaleBy: + SEARCH_ALIASES = ["enlarge latent", "resize latent", "scale latent"] + upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "bislerp"] @classmethod @@ -1335,6 +1377,8 @@ class LatentRotate: return (s,) class LatentFlip: + SEARCH_ALIASES = ["mirror latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1355,6 +1399,8 @@ class LatentFlip: return (s,) class LatentComposite: + SEARCH_ALIASES = ["overlay latent", "layer latent", "paste latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples_to": ("LATENT",), @@ -1397,6 +1443,8 @@ class LatentComposite: return (samples_out,) class LatentBlend: + SEARCH_ALIASES = ["mix latents", "interpolate latents"] + @classmethod def INPUT_TYPES(s): return {"required": { @@ -1438,6 +1486,8 @@ class LatentBlend: raise ValueError(f"Unsupported blend mode: {mode}") class LatentCrop: + SEARCH_ALIASES = ["trim latent", "cut latent"] + @classmethod def INPUT_TYPES(s): return {"required": { "samples": ("LATENT",), @@ -1488,7 +1538,7 @@ class SetLatentNoiseMask: def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False): latent_image = latent["samples"] - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) if disable_noise: noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") @@ -1506,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples return (out, ) @@ -1533,6 +1584,7 @@ class KSampler: CATEGORY = "sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." 
+ SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] def sample(self, model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0): return common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise) @@ -1597,6 +1649,7 @@ class SaveImage: CATEGORY = "image" DESCRIPTION = "Saves the input images to your ComfyUI output directory." + SEARCH_ALIASES = ["save", "save image", "export image", "output image", "write image", "download"] def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): filename_prefix += self.prefix_append @@ -1633,6 +1686,8 @@ class PreviewImage(SaveImage): self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.compress_level = 1 + SEARCH_ALIASES = ["preview", "preview image", "show image", "view image", "display image", "image viewer"] + @classmethod def INPUT_TYPES(s): return {"required": @@ -1651,6 +1706,7 @@ class LoadImage: } CATEGORY = "image" + SEARCH_ALIASES = ["load image", "open image", "import image", "image input", "upload image", "read image", "image loader"] RETURN_TYPES = ("IMAGE", "MASK") FUNCTION = "load_image" @@ -1718,6 +1774,8 @@ class LoadImage: return True class LoadImageMask: + SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] + _color_channels = ["alpha", "red", "green", "blue"] @classmethod def INPUT_TYPES(s): @@ -1768,6 +1826,8 @@ class LoadImageMask: class LoadImageOutput(LoadImage): + SEARCH_ALIASES = ["output image", "previous generation"] + @classmethod def INPUT_TYPES(s): return { @@ -1803,6 +1863,7 @@ class ImageScale: FUNCTION = "upscale" CATEGORY = "image/upscaling" + SEARCH_ALIASES = ["resize", "resize image", "scale image", "image resize", "zoom", "zoom in", "change size"] def upscale(self, image, upscale_method, width, height, crop): if width == 0 and height == 0: @@ -1840,6 +1901,7 @@ class ImageScaleBy: return (s,) class ImageInvert: + SEARCH_ALIASES = ["reverse colors"] @classmethod def INPUT_TYPES(s): @@ -1855,6 +1917,7 @@ class ImageInvert: return (s,) class ImageBatch: + SEARCH_ALIASES = ["combine images", "merge images", "stack images"] @classmethod def INPUT_TYPES(s): @@ -1900,6 +1963,7 @@ class EmptyImage: return (torch.cat((r, g, b), dim=-1), ) class ImagePadForOutpaint: + SEARCH_ALIASES = ["extend canvas", "expand image"] @classmethod def INPUT_TYPES(s): @@ -2041,7 +2105,8 @@ NODE_DISPLAY_NAME_MAPPINGS = { "CheckpointLoader": "Load Checkpoint With Config (DEPRECATED)", "CheckpointLoaderSimple": "Load Checkpoint", "VAELoader": "Load VAE", - "LoraLoader": "Load LoRA", + "LoraLoader": "Load LoRA (Model and CLIP)", + "LoraLoaderModelOnly": "Load LoRA", "CLIPLoader": "Load CLIP", "ControlNetLoader": "Load ControlNet Model", "DiffControlNetLoader": "Load ControlNet Model (diff)", @@ -2331,6 +2396,8 @@ async def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_mahiro.py", + "nodes_lt_upsampler.py", + "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", @@ -2363,6 +2430,10 @@ async def init_builtin_extra_nodes(): "nodes_nop.py", "nodes_kandinsky5.py", "nodes_wanmove.py", + "nodes_image_compare.py", + "nodes_zimage.py", + "nodes_lora_debug.py", + "nodes_color.py" ] import_failed = [] @@ -2375,37 +2446,12 @@ async def init_builtin_extra_nodes(): async def init_builtin_api_nodes(): api_nodes_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
"comfy_api_nodes") - api_nodes_files = [ - "nodes_ideogram.py", - "nodes_openai.py", - "nodes_minimax.py", - "nodes_veo2.py", - "nodes_kling.py", - "nodes_bfl.py", - "nodes_bytedance.py", - "nodes_ltxv.py", - "nodes_luma.py", - "nodes_recraft.py", - "nodes_pixverse.py", - "nodes_stability.py", - "nodes_runway.py", - "nodes_sora.py", - "nodes_topaz.py", - "nodes_tripo.py", - "nodes_moonvalley.py", - "nodes_rodin.py", - "nodes_gemini.py", - "nodes_vidu.py", - "nodes_wan.py", - ] - - if not await load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"): - return api_nodes_files + api_nodes_files = sorted(glob.glob(os.path.join(api_nodes_dir, "nodes_*.py"))) import_failed = [] for node_file in api_nodes_files: - if not await load_custom_node(os.path.join(api_nodes_dir, node_file), module_parent="comfy_api_nodes"): - import_failed.append(node_file) + if not await load_custom_node(node_file, module_parent="comfy_api_nodes"): + import_failed.append(os.path.basename(node_file)) return import_failed diff --git a/pyproject.toml b/pyproject.toml index 60378de1e..042f124e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "ComfyUI" -version = "0.7.0" +version = "0.11.1" readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.9" +requires-python = ">=3.10" [project.urls] homepage = "https://www.comfy.org/" diff --git a/requirements.txt b/requirements.txt index 3a05799eb..3ca417dd8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.35.9 -comfyui-workflow-templates==0.7.65 -comfyui-embedded-docs==0.3.1 +comfyui-frontend-package==1.37.11 +comfyui-workflow-templates==0.8.27 +comfyui-embedded-docs==0.4.0 torch torchsde torchvision @@ -21,6 +21,9 @@ psutil alembic SQLAlchemy av>=14.2.0 +comfy-kitchen>=0.2.7 +comfy-aimdo>=0.1.7 +requests #non essential dependencies: kornia>=0.7.1 diff --git a/server.py b/server.py index 70c8b5e3b..2300393b2 100644 --- a/server.py +++ b/server.py @@ -33,6 +33,8 @@ import node_helpers from comfyui_version import __version__ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal +from app.assets.scanner import seed_assets +from app.assets.api.routes import register_assets_system from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -184,7 +186,7 @@ def create_block_external_middleware(): else: response = await handler(request) - response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self' data:; frame-src 'self'; object-src 'self';" return response return block_external_middleware @@ -235,6 +237,7 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None @@ -653,6 +656,7 @@ class PromptServer(): info = {} info['input'] = obj_class.INPUT_TYPES() info['input_order'] = {key: list(value.keys()) for (key, value) in obj_class.INPUT_TYPES().items()} + 
info['is_input_list'] = getattr(obj_class, "INPUT_IS_LIST", False) info['output'] = obj_class.RETURN_TYPES info['output_is_list'] = obj_class.OUTPUT_IS_LIST if hasattr(obj_class, 'OUTPUT_IS_LIST') else [False] * len(obj_class.RETURN_TYPES) info['output_name'] = obj_class.RETURN_NAMES if hasattr(obj_class, 'RETURN_NAMES') else info['output'] @@ -676,13 +680,21 @@ class PromptServer(): info['deprecated'] = True if getattr(obj_class, "EXPERIMENTAL", False): info['experimental'] = True + if getattr(obj_class, "DEV_ONLY", False): + info['dev_only'] = True if hasattr(obj_class, 'API_NODE'): info['api_node'] = obj_class.API_NODE + + info['search_aliases'] = getattr(obj_class, 'SEARCH_ALIASES', []) return info @routes.get("/object_info") async def get_object_info(request): + try: + seed_assets(["models"]) + except Exception as e: + logging.error(f"Failed to seed assets: {e}") with folder_paths.cache_helper: out = {} for x in nodes.NODE_CLASS_MAPPINGS: diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py new file mode 100644 index 000000000..0a57dd7b5 --- /dev/null +++ b/tests-unit/assets_test/conftest.py @@ -0,0 +1,271 @@ +import contextlib +import json +import os +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Callable, Iterator, Optional + +import pytest +import requests + + +def pytest_addoption(parser: pytest.Parser) -> None: + """ + Allow overriding the database URL used by the spawned ComfyUI process. + Priority: + 1) --db-url command line option + 2) ASSETS_TEST_DB_URL environment variable (used by CI) + 3) default: None (will use file-backed sqlite in temp dir) + """ + parser.addoption( + "--db-url", + action="store", + default=os.environ.get("ASSETS_TEST_DB_URL"), + help="SQLAlchemy DB URL (e.g. 
sqlite:///path/to/db.sqlite3)", + ) + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_base_dirs(root: Path) -> None: + for sub in ("models", "custom_nodes", "input", "output", "temp", "user"): + (root / sub).mkdir(parents=True, exist_ok=True) + + +def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None: + start = time.time() + last_err = None + while time.time() - start < timeout: + try: + r = session.get(base + "/api/assets", timeout=5) + if r.status_code in (200, 400): + return + except Exception as e: + last_err = e + time.sleep(0.25) + raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}") + + +@pytest.fixture(scope="session") +def comfy_tmp_base_dir() -> Path: + env_base = os.environ.get("ASSETS_TEST_BASE_DIR") + created_by_fixture = False + if env_base: + tmp = Path(env_base) + tmp.mkdir(parents=True, exist_ok=True) + else: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + created_by_fixture = True + _make_base_dirs(tmp) + yield tmp + if created_by_fixture: + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() + + +@pytest.fixture(scope="session") +def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest): + """ + Boot ComfyUI subprocess with: + - sandbox base dir + - file-backed sqlite DB in temp dir + - autoscan disabled + Returns (base_url, process, port) + """ + port = _free_port() + db_url = request.config.getoption("--db-url") + if not db_url: + # Use a file-backed sqlite database in the temp directory + db_path = comfy_tmp_base_dir / "assets-test.sqlite3" + db_url = f"sqlite:///{db_path}" + + logs_dir = comfy_tmp_base_dir / "logs" + logs_dir.mkdir(exist_ok=True) + out_log = open(logs_dir / "stdout.log", "w", buffering=1) + err_log = open(logs_dir / "stderr.log", "w", buffering=1) + + comfy_root = Path(__file__).resolve().parent.parent.parent + if not (comfy_root / "main.py").is_file(): + raise FileNotFoundError(f"main.py not found under {comfy_root}") + + proc = subprocess.Popen( + args=[ + sys.executable, + "main.py", + f"--base-directory={str(comfy_tmp_base_dir)}", + f"--database-url={db_url}", + "--disable-assets-autoscan", + "--listen", + "127.0.0.1", + "--port", + str(port), + "--cpu", + ], + stdout=out_log, + stderr=err_log, + cwd=str(comfy_root), + env={**os.environ}, + ) + + for _ in range(50): + if proc.poll() is not None: + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}") + time.sleep(0.1) + + base_url = f"http://127.0.0.1:{port}" + try: + with requests.Session() as s: + _wait_http_ready(base_url, s, timeout=90.0) + yield base_url, proc, port + except Exception as e: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=10) + with contextlib.suppress(Exception): + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI did not become ready: {e}") + + if proc and proc.poll() is None: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=15) + out_log.close() + err_log.close() + + +@pytest.fixture +def http() -> Iterator[requests.Session]: + with requests.Session() as s: + s.timeout = 120 + yield s + + +@pytest.fixture +def 
api_base(comfy_url_and_proc) -> str: + base_url, _proc, _port = comfy_url_and_proc + return base_url + + +def _post_multipart_asset( + session: requests.Session, + base: str, + *, + name: str, + tags: list[str], + meta: dict, + data: bytes, + extra_fields: Optional[dict] = None, +) -> tuple[int, dict]: + files = {"file": (name, data, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + if extra_fields: + for k, v in extra_fields.items(): + form_data[k] = v + r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120) + return r.status_code, r.json() + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str, int], bytes]: + def _make(name: str, size: int = 8192) -> bytes: + seed = sum(ord(c) for c in name) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + return _make + + +@pytest.fixture +def asset_factory(http: requests.Session, api_base: str): + """ + Returns create(name, tags, meta, data) -> response dict + Tracks created ids and deletes them after the test. + """ + created: list[str] = [] + + def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: + status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data) + assert status in (200, 201), body + created.append(body["id"]) + return body + + yield create + + for aid in created: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +@pytest.fixture +def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict: + """ + Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/. + Returns response dict with id, asset_hash, tags, etc. 
+ """ + name = "unit_1_example.safetensors" + p = getattr(request, "param", {}) or {} + tags: Optional[list[str]] = p.get("tags") + if tags is None: + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} + files = {"file": (name, b"A" * 4096, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120) + body = r.json() + assert r.status_code == 201, body + return body + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(http: requests.Session, api_base: str): + """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test.""" + yield + + while True: + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "limit": "500", "sort": "name"}, + timeout=30, + ) + if r.status_code != 200: + break + body = r.json() + ids = [a["id"] for a in body.get("assets", [])] + if not ids: + break + for aid in ids: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a fast sync/seed pass by calling the seed endpoint.""" + session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) + time.sleep(0.2) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py new file mode 100644 index 000000000..78fa7b404 --- /dev/null +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -0,0 +1,348 @@ +import os +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_seed_asset_removed_when_file_is_deleted( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Asset without hash (seed) whose file disappears: + after triggering sync_seed_assets, Asset + AssetInfo disappear. 
+ """ + # Create a file directly under input/unit-tests/ so tags include "unit-tests" + case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" + case_dir.mkdir(parents=True, exist_ok=True) + name = f"seed_{uuid.uuid4().hex[:8]}.bin" + fp = case_dir / name + fp.write_bytes(b"Z" * 2048) + + # Trigger a seed sync so DB sees this path (seed asset => hash is NULL) + trigger_sync_seed_assets(http, api_base) + + # Verify it is visible via API and carries no hash (seed) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body1 = r1.json() + assert r1.status_code == 200 + # there should be exactly one with that name + matches = [a for a in body1.get("assets", []) if a.get("name") == name] + assert matches + assert matches[0].get("asset_hash") is None + asset_info_id = matches[0]["id"] + + # Remove the underlying file and sync again + if fp.exists(): + fp.unlink() + + trigger_sync_seed_assets(http, api_base) + + # It should disappear (AssetInfo and seed Asset gone) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + matches2 = [a for a in body2.get("assets", []) if a.get("name") == name] + assert not matches2, f"Seed asset {asset_info_id} should be gone after sync" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags") +def test_hashed_asset_missing_tag_added_then_removed_after_scan( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed asset with a single cache_state: + 1. delete its file -> sync adds 'missing' + 2. restore file -> sync removes 'missing' + """ + name = "missing_tag_test.png" + tags = ["input", "unit-tests", "msync2"] + data = make_asset_bytes(name, 4096) + a = asset_factory(name, tags, {}, data) + + # Compute its on-disk path and remove it + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png") + assert dest.exists(), f"Expected asset file at {dest}" + dest.unlink() + + # Fast sync should add 'missing' to the AssetInfo + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion" + + # Restore the file with the exact same content and sync again + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify" + + +def test_hashed_asset_two_asset_infos_both_get_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Hashed asset with a single cache_state, but two AssetInfo rows: + deleting the single file then syncing should add 'missing' to both infos. 
+ """ + # Upload one hashed asset + name = "two_infos_one_path.png" + base_tags = ["input", "unit-tests", "multiinfo"] + created = asset_factory(name, base_tags, {}, b"A" * 2048) + + # Create second AssetInfo for the same Asset via from-hash + payload = { + "hash": created["asset_hash"], + "name": "two_infos_one_path_copy.png", + "tags": base_tags, # keep it in our unit-tests scope for cleanup + "user_metadata": {"k": "v"}, + } + r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + second_id = b2["id"] + + # Remove the single underlying file + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") + assert p.exists() + p.unlink() + + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + byname0 = {t["name"]: t for t in tags0.get("tags", [])} + old_missing = int(byname0.get("missing", {}).get("count", 0)) + + # Sync -> both AssetInfos for this asset must receive 'missing' + trigger_sync_seed_assets(http, api_base) + + ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + da = ga.json() + assert ga.status_code == 200, da + assert "missing" in set(da.get("tags", [])) + + gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120) + db = gb.json() + assert gb.status_code == 200, db + assert "missing" in set(db.get("tags", [])) + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfos) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + byname1 = {t["name"]: t for t in tags1.get("tags", [])} + new_missing = int(byname1.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with two cache_state rows: + 1. delete one file -> sync should NOT add 'missing' + 2. delete second file -> sync should add 'missing' + """ + name = "two_cache_states_partial_delete.png" + tags = ["input", "unit-tests", "dual"] + data = make_asset_bytes(name, 3072) + + created = asset_factory(name, tags, {}, data) + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png") + assert path1.exists() + + # Create a second on-disk copy under the same root but different subfolder + path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + + # Fast seed so the second path appears (as a seed initially) + trigger_sync_seed_assets(http, api_base) + + # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner. 
+ run_scan_and_wait("input") + + # Remove only one file and sync -> asset should still be healthy (no 'missing') + path1.unlink() + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" + + # Baseline 'missing' usage count just before last file removal + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0)) + + # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo + path2.unlink() + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """ + Fast pass alone clears 'missing' when size and mtime match exactly: + 1) upload (hashed), record original mtime_ns + 2) delete -> fast pass adds 'missing' + 3) restore same bytes and set mtime back to the original value + 4) run fast pass again -> 'missing' is removed (no slow scan) + """ + scope = f"fastclear-{uuid.uuid4().hex[:6]}" + name = "fastpass_clear.bin" + data = make_asset_bytes(name, 3072) + + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p = base / get_asset_filename(a["asset_hash"], ".bin") + st0 = p.stat() + orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) + + # Delete -> fast pass adds 'missing' + p.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Restore same bytes and revert mtime to the original value + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + # set both atime and mtime in ns to ensure exact match + os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns)) + + # Fast pass should clear 'missing' without a scan + trigger_sync_seed_assets(http, api_base) + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_fastpass_removes_stale_state_row_no_missing( + root: str, + http: requests.Session, + api_base: str, 
+ comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two states: + - delete one file + - run fast pass only + Expect: + - asset stays healthy (no 'missing') + - stale AssetCacheState row for the deleted path is removed. + We verify this behaviorally by recreating the deleted path and running fast pass again: + a new *seed* AssetInfo is created, which proves the old state row was not reused. + """ + scope = f"stale-{uuid.uuid4().hex[:6]}" + name = "two_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload hashed asset at path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + base = comfy_tmp_base_dir / root / "unit-tests" / scope + a1_filename = get_asset_filename(a["asset_hash"], ".bin") + p1 = base / a1_filename + assert p1.exists() + + aid = a["id"] + h = a["asset_hash"] + + # Create second state path2, seed+scan to dedupe into the same Asset + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed + p1.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])) + + # Recreate path1 and run fast pass again. + # If the stale state row was removed, a NEW seed AssetInfo will appear for this path. + p1.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + items = bl.get("assets", []) + # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) + hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)] + assert h in hashes + assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + + # Asset identity still healthy + rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh.status_code == 200 diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py new file mode 100644 index 000000000..d2b69f475 --- /dev/null +++ b/tests-unit/assets_test/test_crud.py @@ -0,0 +1,306 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_create_from_hash_success( + http: requests.Session, api_base: str, seeded_asset: dict +): + h = seeded_asset["asset_hash"] + payload = { + "hash": h, + "name": "from_hash_ok.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "from-hash"], + "user_metadata": {"k": "v"}, + } + r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + assert b1["asset_hash"] == h + assert b1["created_new"] is False + aid = b1["id"] + + # Calling again with the same name should return the same AssetInfo id + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + assert b2["id"] == aid + + +def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # GET detail + rg = 
http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + assert detail["id"] == aid + assert "user_metadata" in detail + assert "filename" in detail["user_metadata"] + + # DELETE + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET again -> 404 + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + body = rg2.json() + assert rg2.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_delete_upon_reference_count( + http: requests.Session, api_base: str, seeded_asset: dict +): + # Create a second reference to the same asset via from-hash + src_hash = seeded_asset["asset_hash"] + payload = { + "hash": src_hash, + "name": "unit_ref_copy.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "del-flow"], + "user_metadata": {"note": "copy"}, + } + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy = r2.json() + assert r2.status_code == 201, copy + assert copy["asset_hash"] == src_hash + assert copy["created_new"] is False + + # Delete original reference -> asset identity must remain + aid1 = seeded_asset["id"] + rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) + assert rd1.status_code == 204 + + rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh1.status_code == 200 # identity still present + + # Delete the last reference with default semantics -> identity and cached files removed + aid2 = copy["id"] + rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) + assert rd2.status_code == 204 + + rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh2.status_code == 404 # orphan content removed + + +def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + original_tags = seeded_asset["tags"] + + payload = { + "name": "unit_1_renamed.safetensors", + "user_metadata": {"purpose": "updated", "epoch": 2}, + } + ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120) + body = ru.json() + assert ru.status_code == 200, body + assert body["name"] == payload["name"] + assert body["tags"] == original_tags # tags unchanged + assert body["user_metadata"]["purpose"] == "updated" + # filename should still be present and normalized by server + assert "filename" in body["user_metadata"] + + +def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict): + h = seeded_asset["asset_hash"] + + # Existing + rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh1.status_code == 200 + + # Non-existent + rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120) + assert rh2.status_code == 404 + + +def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # requests exposes an empty body for HEAD, so validate status and that there is no payload. 
+    rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120)
+    assert rh.status_code == 400
+    body = rh.content
+    assert body == b""
+
+
+def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
+    bogus = str(uuid.uuid4())
+    r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
+    body = r.json()
+    assert r.status_code == 404
+    assert body["error"]["code"] == "ASSET_NOT_FOUND"
+
+
+def test_create_from_hash_invalids(http: requests.Session, api_base: str):
+    # Bad hash algorithm
+    bad = {
+        "hash": "sha256:" + "0" * 64,
+        "name": "x.bin",
+        "tags": ["models", "checkpoints", "unit-tests"],
+    }
+    r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
+    b1 = r1.json()
+    assert r1.status_code == 400
+    assert b1["error"]["code"] == "INVALID_BODY"
+
+    # Invalid JSON body
+    r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
+    b2 = r2.json()
+    assert r2.status_code == 400
+    assert b2["error"]["code"] == "INVALID_JSON"
+
+
+def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
+    # Both lookups should return 404, because the routes match asset IDs with a UUID regex.
+    bad_id = "not-a-uuid"
+
+    r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
+    assert r1.status_code == 404
+
+    r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
+    assert r3.status_code == 404
+
+
+def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
+    aid = seeded_asset["id"]
+    r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120)
+    body = r.json()
+    assert r.status_code == 400
+    assert body["error"]["code"] == "INVALID_BODY"
+
+
+@pytest.mark.parametrize("root", ["input", "output"])
+def test_concurrent_delete_same_asset_info_single_204(
+    root: str,
+    http: requests.Session,
+    api_base: str,
+    asset_factory,
+    make_asset_bytes,
+):
+    """
+    Many concurrent DELETEs for the same AssetInfo should result in:
+    - exactly one 204 No Content (the request that actually deleted the row)
+    - all others 404 Not Found (row already gone)
+    """
+    scope = f"conc-del-{uuid.uuid4().hex[:6]}"
+    name = "to_delete.bin"
+    data = make_asset_bytes(name, 1536)
+
+    created = asset_factory(name, [root, "unit-tests", scope], {}, data)
+    aid = created["id"]
+
+    # Hit the same endpoint N times in parallel.
+    n_tests = 4
+    url = f"{api_base}/api/assets/{aid}?delete_content=false"
+
+    def _do_delete(delete_url):
+        with requests.Session() as s:
+            return s.delete(delete_url, timeout=120).status_code
+
+    with ThreadPoolExecutor(max_workers=n_tests) as ex:
+        statuses = list(ex.map(_do_delete, [url] * n_tests))
+
+    # Exactly one actual delete; the rest must be 404
+    assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
+    assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
+
+    # The resource must be gone.
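+    # (delete_content=false was used above, so presumably only the AssetInfo row is gone
+    # and the stored bytes are left for the test cleanup fixtures to remove.)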
+    rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
+    assert rg.status_code == 404
+
+
+@pytest.mark.parametrize("root", ["input", "output"])
+def test_metadata_filename_is_set_for_seed_asset_without_hash(
+    root: str,
+    http: requests.Session,
+    api_base: str,
+    comfy_tmp_base_dir: Path,
+):
+    """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
+    scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
+    name = "seed_filename.bin"
+
+    base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
+    base.mkdir(parents=True, exist_ok=True)
+    fp = base / name
+    fp.write_bytes(b"Z" * 2048)
+
+    trigger_sync_seed_assets(http, api_base)
+
+    r1 = http.get(
+        api_base + "/api/assets",
+        params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
+        timeout=120,
+    )
+    body = r1.json()
+    assert r1.status_code == 200, body
+    matches = [a for a in body.get("assets", []) if a.get("name") == name]
+    assert matches, "Seed asset should be visible after sync"
+    assert matches[0].get("asset_hash") is None  # still a seed
+    aid = matches[0]["id"]
+
+    r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
+    detail = r2.json()
+    assert r2.status_code == 200, detail
+    filename = (detail.get("user_metadata") or {}).get("filename")
+    expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
+    assert filename == expected, f"expected filename={expected}, got {filename!r}"
+
+
+@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states")
+@pytest.mark.parametrize("root", ["input", "output"])
+def test_metadata_filename_computed_and_updated_on_retarget(
+    root: str,
+    http: requests.Session,
+    api_base: str,
+    comfy_tmp_base_dir: Path,
+    asset_factory,
+    make_asset_bytes,
+    run_scan_and_wait,
+):
+    """
+    1) Ingest under {root}/unit-tests/<scope>/a/b/ -> filename reflects the relative path.
+    2) Retarget by copying to {root}/unit-tests/<scope>/x/, remove the old file,
+       run fast pass + scan -> filename updates to the new relative path.
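+    (Currently skipped: retargeting relies on hashing files found during directory scans,
+    which the scanner does not do yet; see the skip reason above.)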
+ """ + scope = f"meta-fn-{uuid.uuid4().hex[:6]}" + name1 = "compute_metadata_filename.png" + name2 = "compute_changed_metadata_filename.png" + data = make_asset_bytes(name1, 2100) + + # Upload into nested path a/b + a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data) + aid = a["id"] + + root_base = comfy_tmp_base_dir / root + p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png")) + assert p1.exists() + + # filename at ingest should be the path relative to root + rel1 = str(p1.relative_to(root_base)).replace("\\", "/") + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + fn1 = d1["user_metadata"].get("filename") + assert fn1 == rel1 + + # Retarget: copy to x/, remove old, then sync+scan + p2 = root_base / "unit-tests" / scope / "x" / name2 + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + if p1.exists(): + p1.unlink() + + trigger_sync_seed_assets(http, api_base) # seed the new path + run_scan_and_wait(root) # verify/hash and reconcile + + # filename should now point at x/ + rel2 = str(p2.relative_to(root_base)).replace("\\", "/") + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + fn2 = d2["user_metadata"].get("filename") + assert fn2 == rel2 diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py new file mode 100644 index 000000000..cdebf9082 --- /dev/null +++ b/tests-unit/assets_test/test_downloads.py @@ -0,0 +1,166 @@ +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # default attachment + r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + data = r1.content + assert r1.status_code == 200 + cd = r1.headers.get("Content-Disposition", "") + assert "attachment" in cd + assert data and len(data) == 4096 + + # inline requested + r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120) + r2.content + assert r2.status_code == 200 + cd2 = r2.headers.get("Content-Disposition", "") + assert "inline" in cd2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_chooses_existing_state_and_updates_access_time( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two state paths: if the first one disappears, + GET /content still serves from the remaining path and bumps last_access_time. 
+ """ + scope = f"dl-first-{uuid.uuid4().hex[:6]}" + name = "first_existing_state.bin" + data = make_asset_bytes(name, 3072) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + path1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert path1.exists() + + # Seed path2 by copying, then scan to dedupe into a second state + path2 = base / "alt" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove path1 so server must fall back to path2 + path1.unlink() + + # last_access_time before + rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d0 = rg0.json() + assert rg0.status_code == 200, d0 + ts0 = d0.get("last_access_time") + + time.sleep(0.05) + r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + blob = r.content + assert r.status_code == 200 + assert blob == data # must serve from the surviving state (same bytes) + + rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = rg1.json() + assert rg1.status_code == 200, d1 + ts1 = d1.get("last_access_time") + + def _parse_iso8601(s: Optional[str]) -> Optional[float]: + if not s: + return None + s = s[:-1] if s.endswith("Z") else s + return datetime.fromisoformat(s).timestamp() + + t0 = _parse_iso8601(ts0) + t1 = _parse_iso8601(ts1) + assert t1 is not None + if t0 is not None: + assert t1 > t0 + + +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +def test_download_missing_file_returns_404( + http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict +): + # Remove the underlying file then attempt download. + # We initialize fixture without additional tags to know exactly the asset file path. + try: + aid = seeded_asset["id"] + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200 + asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors") + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename + assert abs_path.exists() + abs_path.unlink() + + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + assert r2.status_code == 404 + body = r2.json() + assert body["error"]["code"] == "FILE_NOT_FOUND" + finally: + # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. 
+ dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr.content + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_404_if_all_states_missing( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Multi-state asset: after the last remaining on-disk file is removed, download must return 404.""" + scope = f"dl-404-{uuid.uuid4().hex[:6]}" + name = "missing_all_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert p1.exists() + + # Seed a second state and dedupe + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove first file -> download should still work via the second state + p1.unlink() + ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + b1 = ok1.content + assert ok1.status_code == 200 and b1 == data + + # Remove the last file -> download must 404 + p2.unlink() + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + body = r2.json() + assert r2.status_code == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py new file mode 100644 index 000000000..82e109832 --- /dev/null +++ b/tests-unit/assets_test/test_list_filter.py @@ -0,0 +1,342 @@ +import time +import uuid + +import requests + + +def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"] + for n in names: + asset_factory( + n, + ["models", "checkpoints", "unit-tests", "paging"], + {"epoch": 1}, + make_asset_bytes(n, size=2048), + ) + + # name ascending for stable order + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == sorted(names)[:2] + assert b1["has_more"] is True + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == sorted(names)[2:] + assert b2["has_more"] is False + + +def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory): + a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + r2 = http.get( + api_base + 
"/api/assets", + params={"include_tags": "unit-tests", "name_contains": "inc_"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in body2["assets"]] + assert a["name"] in names2 + assert b["name"] in names2 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "non-existing-tag"}, + timeout=120, + ) + body3 = r2.json() + assert r2.status_code == 200 + assert not body3["assets"] + + +def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-size"] + n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" + asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {}, make_asset_bytes(n3, 3072)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"}, + timeout=120, + ) + b1 = r1.json() + names = [a["name"] for a in b1["assets"]] + assert names[:3] == [n1, n2, n3] + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"}, + timeout=120, + ) + b2 = r2.json() + names2 = [a["name"] for a in b2["assets"]] + assert names2[:3] == [n3, n2, n1] + + + +def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-upd"] + a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) + a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) + + # Rename the second asset to bump updated_at + rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120) + upd = rp.json() + assert rp.status_code == 200, upd + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == "upd_b_renamed.safetensors" + assert a1["name"] in names + + + +def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-access"] + asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) + time.sleep(0.02) + a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) + + # Touch last_access_time of b by downloading its content + time.sleep(0.02) + dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120) + assert dl.status_code == 200 + dl.content + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == a2["name"] + + +def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-include"] + a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) + asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) + + # CSV + case-insensitive + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + timeout=120, + ) + b1 = r1.json() + assert 
r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + assert not any("beta" in x for x in names1) + + # Repeated query params for include_tags + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-include"), + ("include_tags", "alpha"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a["name"] in names2 + assert not any("beta" in x for x in names2) + + # Duplicates and spaces in CSV + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": " unit-tests , lf-include , alpha , alpha "}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a["name"] in names3 + + +def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) + asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) + + # Exclude uppercase should work + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + # Repeated excludes with duplicates + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-exclude"), + ("exclude_tags", "beta"), + ("exclude_tags", "beta"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert all("beta" not in x for x in names2) + + +def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-name"] + a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) + a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a1["name"] in names1 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a1["name"] in names2 + + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a2["name"] in names3 + + +def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) + asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) + asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) + + # Offset far beyond total + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": 
"10"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + assert not b1["assets"] + assert b1["has_more"] is False + + # Boundary large limit (<=500 is valid) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert len(b2["assets"]) == 3 + assert b2["has_more"] is False + + +def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): + r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): + # limit too small + r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + # bad metadata JSON + r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_name_contains_literal_underscore( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'name_contains' must treat '_' literally, not as a SQL wildcard. + We create: + - foo_bar.safetensors (should match) + - fooxbar.safetensors (must NOT match if '_' is escaped) + - foobar.safetensors (must NOT match) + """ + scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" + tags = ["models", "checkpoints", "unit-tests", scope] + + a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) + b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) + c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700)) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200, body + names = [x["name"] for x in body["assets"]] + assert a["name"] in names, f"Expected literal underscore match to include {a['name']}" + assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'" + assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'" + assert body["total"] == 1 diff --git a/tests-unit/assets_test/test_metadata_filters.py b/tests-unit/assets_test/test_metadata_filters.py new file mode 100644 index 000000000..20285a3b3 --- /dev/null +++ b/tests-unit/assets_test/test_metadata_filters.py @@ -0,0 +1,395 @@ +import json + + +def test_meta_and_across_keys_and_types( + http, api_base: str, asset_factory, make_asset_bytes +): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = [a["name"] for a in 
b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert not b2["assets"] + + +def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + r3 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": True}), + }, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + r4 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + timeout=120, + ) + b4 = r4.json() + assert r4.status_code == 200 and not b4["assets"] + + +def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + +def test_meta_none_semantics_missing_or_null_and_any_of_with_none( + http, api_base, asset_factory, make_asset_bytes +): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = 
asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + # Exact JSON object equality (same structure) + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", 
"metadata_filter": json.dumps(filt)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"]) + + +def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches second asset + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, 
make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" + asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = { + "include_tags": "unit-tests,mf-sort", + "metadata_filter": json.dumps({"group": "g"}), + "sort": "size", "order": "asc", "limit": "2", + } + r1 = http.get(api_base + "/api/assets", params=q, timeout=120) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + r2 = http.get(api_base + "/api/assets", params=q2, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py new file mode 100644 index 000000000..f602e5a77 --- /dev/null +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -0,0 +1,141 @@ +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.fixture +def create_seed_file(comfy_tmp_base_dir: Path): + """Create a file on disk that will become a seed asset after sync.""" + created: list[Path] = [] + + def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path: + name = name or f"seed_{uuid.uuid4().hex[:8]}.bin" + path = comfy_tmp_base_dir / root / "unit-tests" / scope / name + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + created.append(path) + return path + + yield _create + + for p in created: + p.unlink(missing_ok=True) + + +@pytest.fixture +def find_asset(http: requests.Session, api_base: str): + """Query API for assets matching scope and optional name.""" + def _find(scope: str, name: str | None = None) -> list[dict]: + params = {"include_tags": f"unit-tests,{scope}"} + if name: + params["name_contains"] = name + r = http.get(f"{api_base}/api/assets", params=params, timeout=120) + assert r.status_code == 200 + assets = r.json().get("assets", []) + if name: + return [a for a in assets if a.get("name") == name] + return assets + + return _find + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_orphaned_seed_asset_is_pruned( + root: str, + 
create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with deleted file is removed; with file present, it survives.""" + scope = f"prune-{uuid.uuid4().hex[:6]}" + fp = create_seed_file(root, scope) + name = fp.name + + trigger_sync_seed_assets(http, api_base) + assert find_asset(scope, name), "Seed asset should exist" + + fp.unlink() + trigger_sync_seed_assets(http, api_base) + assert not find_asset(scope, name), "Orphaned seed should be pruned" + + +def test_seed_asset_with_file_survives_prune( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with file still on disk is NOT pruned.""" + scope = f"keep-{uuid.uuid4().hex[:6]}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope, fp.name), "Seed with valid file should survive" + + +def test_hashed_asset_not_pruned_when_file_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed assets are never deleted by prune, even without file.""" + scope = f"hashed-{uuid.uuid4().hex[:6]}" + data = make_asset_bytes("test", 2048) + a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data) + + path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin") + path.unlink() + + trigger_sync_seed_assets(http, api_base) + + r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + assert r.status_code == 200, "Hashed asset should NOT be pruned" + + +def test_prune_across_multiple_roots( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Prune correctly handles assets across input and output roots.""" + scope = f"multi-{uuid.uuid4().hex[:6]}" + input_fp = create_seed_file("input", scope, "input.bin") + create_seed_file("output", scope, "output.bin") + + trigger_sync_seed_assets(http, api_base) + assert len(find_asset(scope)) == 2 + + input_fp.unlink() + trigger_sync_seed_assets(http, api_base) + + remaining = find_asset(scope) + assert len(remaining) == 1 + assert remaining[0]["name"] == "output.bin" + + +@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"]) +def test_special_chars_in_path_escaped_correctly( + dirname: str, + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches.""" + scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive" diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags.py new file mode 100644 index 000000000..6b1047802 --- /dev/null +++ b/tests-unit/assets_test/test_tags.py @@ -0,0 +1,225 @@ +import json +import uuid + +import requests + + +def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict): + # Include zero-usage tags by default + r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags before we add anything 
new from this test cycle + r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) + body2 = r2.json() + assert r2.status_code == 200 + # We already seeded one asset via fixture, so used tags must be non-empty + used_names = [t["name"] for t in body2["tags"]] + assert "models" in used_names + assert "checkpoints" in used_names + + # Prefix filter should refine the list + r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [t["name"] for t in b3["tags"]] + assert "unit-tests" in names3 + assert "models" not in names3 # filtered out by prefix + + # Order by name ascending should be stable + r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120) + b4 = r4.json() + assert r4.status_code == 200 + names4 = [t["name"] for t in b4["tags"]] + assert names4 == sorted(names4) + + +def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + # Baseline: system tags exist when include_zero (default) is true + r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + assert "models" in names and "checkpoints" in names + + # Create a short-lived asset under input with a unique custom tag + scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" + custom_tag = f"temp-{uuid.uuid4().hex[:8]}" + name = "tag_seed.bin" + _asset = asset_factory( + name, + ["input", "unit-tests", scope, custom_tag], + {}, + make_asset_bytes(name, 512), + ) + + # While the asset exists, the custom tag must appear when include_zero=false + r2 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + used_names = [t["name"] for t in body2["tags"]] + assert custom_tag in used_names + + # Delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + assert rd.status_code == 204 + + # Now the custom tag must not be returned when include_zero=false + r3 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body3 = r3.json() + assert r3.status_code == 200 + names_after = [t["name"] for t in body3["tags"]] + assert custom_tag not in names_after + assert not names_after # filtered view should be empty now + + +def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add tags with duplicates and mixed case + payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g = rg.json() + assert rg.status_code == 200 + tags_now = set(g["tags"]) + assert {"newtag", "beta"}.issubset(tags_now) + + # Remove a tag and a non-existent tag + payload_del = {"tags": ["newtag", "does-not-exist"]} + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", 
json=payload_del, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g2 = rg2.json() + assert rg2.status_code == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present + + +def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + h = seeded_asset["asset_hash"] + + # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1) + r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120) + add_body = r_add.json() + assert r_add.status_code == 200, add_body + + # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'. + payload = { + "hash": h, + "name": "order_only_bbb.safetensors", + "tags": ["input", "unit-tests", "orderbbb"], + "user_metadata": {}, + } + r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy_body = r_copy.json() + assert r_copy.status_code == 201, copy_body + + # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa' + # because it has higher usage (2 vs 1). + r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + names1 = [t["name"] for t in b1["tags"]] + counts1 = {t["name"]: t["count"] for t in b1["tags"]} + # Both must be present within the prefix subset + assert "orderaaa" in names1 and "orderbbb" in names1 + # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1 + assert counts1["orderbbb"] >= counts1["orderaaa"] + # And with count_desc, 'orderbbb' appears earlier than 'orderaaa' + assert names1.index("orderbbb") < names1.index("orderaaa") + + # 2) name_asc: lexical order should flip the relative order + r2 = http.get( + api_base + "/api/tags", + params={"prefix": "order", "include_zero": "false", "order": "name_asc"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200, b2 + names2 = [t["name"] for t in b2["tags"]] + assert "orderaaa" in names2 and "orderbbb" in names2 + assert names2.index("orderaaa") < names2.index("orderbbb") + + # 3) invalid limit rejected (existing negative case retained) + r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add with empty list + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Remove with wrong type + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_BODY" + + # metadata_filter provided as JSON array should be rejected (must be object) + r3 = http.get( + api_base + "/api/assets", + params={"metadata_filter": json.dumps([{"x": 1}])}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_prefix_treats_underscore_literal( + http, + api_base, + 
asset_factory,
+    make_asset_bytes,
+):
+    """'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
+    base = f"pref_{uuid.uuid4().hex[:6]}"
+    tag_ok = f"{base}_ok"  # should match prefix=f"{base}_"
+    tag_bad = f"{base}xok"  # must NOT match if '_' is escaped
+    scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
+
+    asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
+    asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
+
+    r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120)
+    body = r.json()
+    assert r.status_code == 200, body
+    names = [t["name"] for t in body["tags"]]
+    assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
+    assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
+    assert body["total"] == 1
diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py
new file mode 100644
index 000000000..137d7391a
--- /dev/null
+++ b/tests-unit/assets_test/test_uploads.py
@@ -0,0 +1,281 @@
+import json
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+import pytest
+
+
+def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
+    name = "dup_a.safetensors"
+    tags = ["models", "checkpoints", "unit-tests", "alpha"]
+    meta = {"purpose": "dup"}
+    data = make_asset_bytes(name)
+    files = {"file": (name, data, "application/octet-stream")}
+    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
+    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    a1 = r1.json()
+    assert r1.status_code == 201, a1
+    assert a1["created_new"] is True
+
+    # Second upload with the same data and name should return created_new == False and the same AssetInfo
+    files = {"file": (name, data, "application/octet-stream")}
+    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
+    r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    a2 = r2.json()
+    assert r2.status_code == 200, a2
+    assert a2["created_new"] is False
+    assert a2["asset_hash"] == a1["asset_hash"]
+    assert a2["id"] == a1["id"]  # same reference as before
+
+    # Third upload with the same data but a new name should return created_new == False and a new AssetInfo
+    files = {"file": (name, data, "application/octet-stream")}
+    form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)}
+    r3 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    a3 = r3.json()
+    assert r3.status_code == 200, a3
+    assert a3["created_new"] is False
+    assert a3["asset_hash"] == a1["asset_hash"]
+    assert a3["id"] != a1["id"]  # new reference to the same content
+
+
+def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
+    # Seed a small file first
+    name = "fastpath_seed.safetensors"
+    tags = ["models", "checkpoints", "unit-tests"]
+    meta = {}
+    files = {"file": (name, b"B" * 1024, "application/octet-stream")}
+    form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
+    r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    b1 = r1.json()
+    assert r1.status_code == 201, b1
+    h = b1["asset_hash"]
+
+    # Now POST /api/assets with only a hash and no file
+    files = [
+        ("hash", (None, h)),
+        ("tags", (None, json.dumps(tags))),
+        ("name", (None, 
"fastpath_copy.safetensors")), + ("user_metadata", (None, json.dumps({"purpose": "copy"}))), + ] + r2 = http.post(api_base + "/api/assets", files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_fastpath_with_known_hash_and_file( + http: requests.Session, api_base: str +): + # Seed + files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + h = b1["asset_hash"] + + # Send both file and hash of existing content -> server must drain file and create from hash (200) + files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")} + form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): + data = [ + ("tags", "models,checkpoints"), + ("tags", json.dumps(["unit-tests", "alpha"])), + ("name", "merge.safetensors"), + ("user_metadata", json.dumps({"u": 1})), + ] + files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")} + r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120) + created = r1.json() + assert r1.status_code in (200, 201), created + aid = created["id"] + + # Verify all tags are present on the resource + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + tags = set(detail["tags"]) + assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_concurrent_upload_identical_bytes_different_names( + root: str, + http: requests.Session, + api_base: str, + make_asset_bytes, +): + """ + Two concurrent uploads of identical bytes but different names. + Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. 
+ """ + scope = f"concupload-{uuid.uuid4().hex[:6]}" + name1, name2 = "cu_a.bin", "cu_b.bin" + data = make_asset_bytes("concurrent", 4096) + tags = [root, "unit-tests", scope] + + def _do_upload(args): + url, form_data, files_data = args + with requests.Session() as s: + return s.post(url, data=form_data, files=files_data, timeout=120) + + url = api_base + "/api/assets" + form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})} + files1 = {"file": (name1, data, "application/octet-stream")} + form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})} + files2 = {"file": (name2, data, "application/octet-stream")} + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)])) + r1, r2 = futures + + b1, b2 = r1.json(), r2.json() + assert r1.status_code in (200, 201), b1 + assert r2.status_code in (200, 201), b2 + assert b1["asset_hash"] == b2["asset_hash"] + assert b1["id"] != b2["id"] + + created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))]) + assert created_flags == [False, True] + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "sort": "name"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + names = [a["name"] for a in bl.get("assets", [])] + assert set([name1, name2]).issubset(names) + + +def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str): + payload = { + "hash": "blake3:" + "0" * 64, + "name": "nonexistent.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + body = r.json() + assert r.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): + files = {"file": ("empty.safetensors", b"", "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "EMPTY_UPLOAD" + + +def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): + files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): + files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_requires_multipart(http: requests.Session, api_base: str): + r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120) + body = r.json() + assert r.status_code == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +def test_upload_missing_file_and_hash(http: requests.Session, 
+    files = [
+        ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
+        ("name", (None, "x.safetensors")),
+    ]
+    r = http.post(api_base + "/api/assets", files=files, timeout=120)
+    body = r.json()
+    assert r.status_code == 400
+    assert body["error"]["code"] == "MISSING_FILE"
+
+
+def test_upload_models_unknown_category(http: requests.Session, api_base: str):
+    files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
+    form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
+    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    body = r.json()
+    assert r.status_code == 400
+    assert body["error"]["code"] == "INVALID_BODY"
+    assert body["error"]["message"].startswith("unknown models category")
+
+
+def test_upload_models_requires_category(http: requests.Session, api_base: str):
+    files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
+    form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
+    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    body = r.json()
+    assert r.status_code == 400
+    assert body["error"]["code"] == "INVALID_BODY"
+
+
+def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
+    files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
+    form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
+    r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
+    body = r.json()
+    assert r.status_code == 400
+    assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
+
+
+@pytest.mark.parametrize("root", ["input", "output"])
+def test_duplicate_upload_same_display_name_does_not_clobber(
+    root: str,
+    http: requests.Session,
+    api_base: str,
+    asset_factory,
+    make_asset_bytes,
+):
+    """
+    Two uploads use the same tags and the same display name but different bytes.
+    With hash-based filenames, they must NOT overwrite each other. Both assets
+    remain accessible and serve their original content.
+ """ + scope = f"dup-path-{uuid.uuid4().hex[:6]}" + display_name = "same_display.bin" + + d1 = make_asset_bytes(scope + "-v1", 1536) + d2 = make_asset_bytes(scope + "-v2", 2048) + tags = [root, "unit-tests", scope] + + first = asset_factory(display_name, tags, {}, d1) + second = asset_factory(display_name, tags, {}, d2) + + assert first["id"] != second["id"] + assert first["asset_hash"] != second["asset_hash"] # different content + assert first["name"] == second["name"] == display_name + + # Both must be independently retrievable + r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120) + b1 = r1.content + assert r1.status_code == 200 + assert b1 == d1 + r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120) + b2 = r2.content + assert r2.status_code == 200 + assert b2 == d2 diff --git a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py b/tests-unit/comfy_api_nodes_test/mapper_utils_test.py deleted file mode 100644 index 69488f691..000000000 --- a/tests-unit/comfy_api_nodes_test/mapper_utils_test.py +++ /dev/null @@ -1,297 +0,0 @@ -from typing import Optional -from enum import Enum - -from pydantic import BaseModel, Field - -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.mapper_utils import model_field_to_node_input - - -def test_model_field_to_float_input(): - """Tests mapping a float field with constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field( - default=0.5, - description="Flexibility in video generation", - ge=0.0, - le=1.0, - multiple_of=0.001, - ) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - "tooltip": "Flexibility in video generation", - "min": 0.0, - "max": 1.0, - "step": 0.001, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_float_input_no_constraints(): - """Tests mapping a float field with no constraints.""" - - class ModelWithFloatField(BaseModel): - cfg_scale: Optional[float] = Field(default=0.5) - - expected_output = ( - IO.FLOAT, - { - "default": 0.5, - }, - ) - - actual_output = model_field_to_node_input( - IO.FLOAT, ModelWithFloatField, "cfg_scale" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_int_input(): - """Tests mapping an int field with constraints.""" - - class ModelWithIntField(BaseModel): - num_frames: Optional[int] = Field( - default=10, - description="Number of frames to generate", - ge=1, - le=100, - multiple_of=1, - ) - - expected_output = ( - IO.INT, - { - "default": 10, - "tooltip": "Number of frames to generate", - "min": 1, - "max": 100, - "step": 1, - }, - ) - - actual_output = model_field_to_node_input(IO.INT, ModelWithIntField, "num_frames") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.STRING, ModelWithStringField, "prompt") - - assert actual_output[0] == 
expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_string_input_multiline(): - """Tests mapping a string field.""" - - class ModelWithStringField(BaseModel): - prompt: Optional[str] = Field( - default="A beautiful sunset over a calm ocean", - description="A prompt for the video generation", - ) - - expected_output = ( - IO.STRING, - { - "default": "A beautiful sunset over a calm ocean", - "tooltip": "A prompt for the video generation", - "multiline": True, - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithStringField, "prompt", multiline=True - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input(): - """Tests mapping a combo field.""" - - class MockEnum(str, Enum): - option_1 = "option 1" - option_2 = "option 2" - option_3 = "option 3" - - class ModelWithComboField(BaseModel): - model_name: Optional[MockEnum] = Field("option 1", description="Model Name") - - expected_output = ( - IO.COMBO, - { - "options": ["option 1", "option 2", "option 3"], - "default": "option 1", - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name", enum_type=MockEnum - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_combo_input_no_options(): - """Tests mapping a combo field with no options.""" - - class ModelWithComboField(BaseModel): - model_name: Optional[str] = Field(description="Model Name") - - expected_output = ( - IO.COMBO, - { - "tooltip": "Model Name", - }, - ) - - actual_output = model_field_to_node_input( - IO.COMBO, ModelWithComboField, "model_name" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_image_input(): - """Tests mapping an image field.""" - - class ModelWithImageField(BaseModel): - image: Optional[str] = Field( - default=None, - description="An image for the video generation", - ) - - expected_output = ( - IO.IMAGE, - { - "default": None, - "tooltip": "An image for the video generation", - }, - ) - - actual_output = model_field_to_node_input(IO.IMAGE, ModelWithImageField, "image") - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_description(): - """Tests mapping a field with no description.""" - - class ModelWithNoDescriptionField(BaseModel): - field: Optional[str] = Field(default="default value") - - expected_output = ( - IO.STRING, - { - "default": "default value", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDescriptionField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_default(): - """Tests mapping a field with no default.""" - - class ModelWithNoDefaultField(BaseModel): - field: Optional[str] = Field(description="A field with no default") - - expected_output = ( - IO.STRING, - { - "tooltip": "A field with no default", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_no_metadata(): - """Tests mapping a field with no metadata or properties defined on the schema.""" - - class 
ModelWithNoMetadataField(BaseModel): - field: Optional[str] = Field() - - expected_output = ( - IO.STRING, - {}, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoMetadataField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] - - -def test_model_field_to_node_input_default_is_none(): - """ - Tests mapping a field with a default of `None`. - I.e., the default field should be included as the schema explicitly sets it to `None`. - """ - - class ModelWithNoneDefaultField(BaseModel): - field: Optional[str] = Field( - default=None, description="A field with a default of None" - ) - - expected_output = ( - IO.STRING, - { - "default": None, - "tooltip": "A field with a default of None", - }, - ) - - actual_output = model_field_to_node_input( - IO.STRING, ModelWithNoneDefaultField, "field" - ) - - assert actual_output[0] == expected_output[0] - assert actual_output[1] == expected_output[1] diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 3a54941e6..7c740491d 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -103,18 +103,18 @@ class TestMixedPrecisionOps(unittest.TestCase): # Verify weights are wrapped in QuantizedTensor self.assertIsInstance(model.layer1.weight, QuantizedTensor) - self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Layer 2 should NOT be quantized self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) # Layer 3 should be quantized self.assertIsInstance(model.layer3.weight, QuantizedTensor) - self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify scales were loaded - self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) - self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + self.assertEqual(model.layer1.weight._params.scale.item(), 2.0) + self.assertEqual(model.layer3.weight._params.scale.item(), 1.5) # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) @@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase): state_dict2 = model.state_dict() # Verify layer1.weight is a QuantizedTensor with scale preserved - self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) - self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) - self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8))) + self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0) + self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout") # Verify non-quantized layers are standard tensors self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py deleted file mode 100644 index 9cb54ede8..000000000 --- a/tests-unit/comfy_quant/test_quant_registry.py +++ /dev/null @@ -1,190 +0,0 @@ -import unittest -import torch -import sys -import os - -# Add comfy to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) - -def has_gpu(): - return torch.cuda.is_available() - -from comfy.cli_args import args 
-if not has_gpu(): - args.cpu = True - -from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout - - -class TestQuantizedTensor(unittest.TestCase): - """Test the QuantizedTensor subclass with FP8 layout""" - - def test_creation(self): - """Test creating a QuantizedTensor with TensorCoreFP8Layout""" - fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(2.0) - layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.shape, (256, 128)) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt._layout_params['scale'], scale) - self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) - self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") - - def test_dequantize(self): - """Test explicit dequantization""" - - fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(3.0) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - dequantized = qt.dequantize() - - self.assertEqual(dequantized.dtype, torch.float32) - self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) - - def test_from_float(self): - """Test creating QuantizedTensor from float tensor""" - float_tensor = torch.randn(64, 32, dtype=torch.float32) - scale = torch.tensor(1.5) - - qt = QuantizedTensor.from_float( - float_tensor, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertIsInstance(qt, QuantizedTensor) - self.assertEqual(qt.dtype, torch.float8_e4m3fn) - self.assertEqual(qt.shape, (64, 32)) - - # Verify dequantization gives approximately original values - dequantized = qt.dequantize() - mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() - self.assertLess(mean_rel_error, 0.1) - - -class TestGenericUtilities(unittest.TestCase): - """Test generic utility operations""" - - def test_detach(self): - """Test detach operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Detach should return a new QuantizedTensor - qt_detached = qt.detach() - - self.assertIsInstance(qt_detached, QuantizedTensor) - self.assertEqual(qt_detached.shape, qt.shape) - self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") - - def test_clone(self): - """Test clone operation on quantized tensor""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = {'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Clone should return a new QuantizedTensor - qt_cloned = qt.clone() - - self.assertIsInstance(qt_cloned, QuantizedTensor) - self.assertEqual(qt_cloned.shape, qt.shape) - self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") - - # Verify it's a deep copy - self.assertIsNot(qt_cloned._qdata, qt._qdata) - - @unittest.skipUnless(has_gpu(), "GPU not available") - def test_to_device(self): - """Test device transfer""" - fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) - scale = torch.tensor(1.5) - layout_params = 
{'scale': scale, 'orig_dtype': torch.float32} - qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) - - # Moving to same device should work (CPU to CPU) - qt_cpu = qt.to('cpu') - - self.assertIsInstance(qt_cpu, QuantizedTensor) - self.assertEqual(qt_cpu.device.type, 'cpu') - self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') - - -class TestTensorCoreFP8Layout(unittest.TestCase): - """Test the TensorCoreFP8Layout implementation""" - - def test_quantize(self): - """Test quantization method""" - float_tensor = torch.randn(32, 64, dtype=torch.float32) - scale = torch.tensor(1.5) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - self.assertEqual(qdata.dtype, torch.float8_e4m3fn) - self.assertEqual(qdata.shape, float_tensor.shape) - self.assertIn('scale', layout_params) - self.assertIn('orig_dtype', layout_params) - self.assertEqual(layout_params['orig_dtype'], torch.float32) - - def test_dequantize(self): - """Test dequantization method""" - float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 - scale = torch.tensor(1.0) - - qdata, layout_params = TensorCoreFP8Layout.quantize( - float_tensor, - scale=scale, - dtype=torch.float8_e4m3fn - ) - - dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) - - # Should approximately match original - self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) - - -class TestFallbackMechanism(unittest.TestCase): - """Test fallback for unsupported operations""" - - def test_unsupported_op_dequantizes(self): - """Test that unsupported operations fall back to dequantization""" - # Set seed for reproducibility - torch.manual_seed(42) - - # Create quantized tensor - a_fp32 = torch.randn(10, 20, dtype=torch.float32) - scale = torch.tensor(1.0) - a_q = QuantizedTensor.from_float( - a_fp32, - "TensorCoreFP8Layout", - scale=scale, - dtype=torch.float8_e4m3fn - ) - - # Call an operation that doesn't have a registered handler - # For example, torch.abs - result = torch.abs(a_q) - - # Should work via fallback (dequantize → abs → return) - self.assertNotIsInstance(result, QuantizedTensor) - expected = torch.abs(a_fp32) - # FP8 introduces quantization error, so use loose tolerance - mean_error = (result - expected).abs().mean() - self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 3a6790ee0..2355b8000 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,3 +2,4 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client +blake3 diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py index 918c8080a..4d2f9ed36 100644 --- a/tests/execution/test_jobs.py +++ b/tests/execution/test_jobs.py @@ -19,6 +19,7 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS == 'in_progress' assert JobStatus.COMPLETED == 'completed' assert JobStatus.FAILED == 'failed' + assert JobStatus.CANCELLED == 'cancelled' def test_all_contains_all_statuses(self): """ALL should contain all status values.""" @@ -26,7 +27,8 @@ class TestJobStatus: assert JobStatus.IN_PROGRESS in JobStatus.ALL assert JobStatus.COMPLETED in JobStatus.ALL assert JobStatus.FAILED in JobStatus.ALL - assert len(JobStatus.ALL) == 4 + assert JobStatus.CANCELLED in JobStatus.ALL + assert len(JobStatus.ALL) == 5 class TestIsPreviewable: @@ -336,6 +338,40 @@ class TestNormalizeHistoryItem: 
assert job['execution_error']['node_type'] == 'KSampler' assert job['execution_error']['exception_message'] == 'CUDA out of memory' + def test_cancelled_job(self): + """Cancelled/interrupted history item should have cancelled status.""" + history_item = { + 'prompt': ( + 5, + 'prompt-cancelled', + {'nodes': {}}, + {'create_time': 1234567890000}, + ['node1'], + ), + 'status': { + 'status_str': 'error', + 'completed': False, + 'messages': [ + ('execution_start', {'prompt_id': 'prompt-cancelled', 'timestamp': 1234567890500}), + ('execution_interrupted', { + 'prompt_id': 'prompt-cancelled', + 'node_id': '5', + 'node_type': 'KSampler', + 'executed': ['1', '2', '3'], + 'timestamp': 1234567891000, + }) + ] + }, + 'outputs': {}, + } + + job = normalize_history_item('prompt-cancelled', history_item) + assert job['status'] == 'cancelled' + assert job['execution_start_time'] == 1234567890500 + assert job['execution_end_time'] == 1234567891000 + # Cancelled jobs should not have execution_error set + assert 'execution_error' not in job + def test_include_outputs(self): """When include_outputs=True, should include full output data.""" history_item = {