diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
index ed00583b6..4501ef9a1 100644
--- a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
+++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat
@@ -1,3 +1,3 @@
 ..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version. If you get a c10.dll error, install the Visual C++ redistributable from: https://aka.ms/vc14/vc_redist.x64.exe
 pause
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
index 4898a424f..6487ac7ce 100755
--- a/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu.bat
@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version. If you get a c10.dll error, install the Visual C++ redistributable from: https://aka.ms/vc14/vc_redist.x64.exe
 pause
diff --git a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
index 32611e4af..01c5bb33b 100644
--- a/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
+++ b/.ci/windows_nvidia_base_files/run_nvidia_gpu_fast_fp16_accumulation.bat
@@ -1,3 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
-echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
+echo If you see this and ComfyUI did not start, try updating your Nvidia drivers to the latest version. If you get a c10.dll error, install the Visual C++ redistributable from: https://aka.ms/vc14/vc_redist.x64.exe
 pause
diff --git a/.coderabbit.yaml b/.coderabbit.yaml
new file mode 100644
index 000000000..0d1e49270
--- /dev/null
+++ b/.coderabbit.yaml
@@ -0,0 +1,127 @@
+# yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json
+language: "en-US"
+early_access: false
+tone_instructions: "Only comment on issues introduced by this PR's changes. Do not flag pre-existing problems in moved, re-indented, or reformatted code."
+
+reviews:
+  profile: "chill"
+  request_changes_workflow: false
+  high_level_summary: false
+  poem: false
+  review_status: false
+  review_details: false
+  commit_status: true
+  collapse_walkthrough: true
+  changed_files_summary: false
+  sequence_diagrams: false
+  estimate_code_review_effort: false
+  assess_linked_issues: false
+  related_issues: false
+  related_prs: false
+  suggested_labels: false
+  auto_apply_labels: false
+  suggested_reviewers: false
+  auto_assign_reviewers: false
+  in_progress_fortune: false
+  enable_prompt_for_ai_agents: true
+
+  path_filters:
+    - "!comfy_api_nodes/apis/**"
+    - "!**/generated/*.pyi"
+    - "!.ci/**"
+    - "!script_examples/**"
+    - "!**/__pycache__/**"
+    - "!**/*.ipynb"
+    - "!**/*.png"
+    - "!**/*.bat"
+
+  path_instructions:
+    - path: "**"
+      instructions: |
+        IMPORTANT: Only comment on issues directly introduced by this PR's code changes.
+ Do NOT flag pre-existing issues in code that was merely moved, re-indented, + de-indented, or reformatted without logic changes. If code appears in the diff + only due to whitespace or structural reformatting (e.g., removing a `with:` block), + treat it as unchanged. Contributors should not feel obligated to address + pre-existing issues outside the scope of their contribution. + - path: "comfy/**" + instructions: | + Core ML/diffusion engine. Focus on: + - Backward compatibility (breaking changes affect all custom nodes) + - Memory management and GPU resource handling + - Performance implications in hot paths + - Thread safety for concurrent execution + - path: "comfy_api_nodes/**" + instructions: | + Third-party API integration nodes. Focus on: + - No hardcoded API keys or secrets + - Proper error handling for API failures (timeouts, rate limits, auth errors) + - Correct Pydantic model usage + - Security of user data passed to external APIs + - path: "comfy_extras/**" + instructions: | + Community-contributed extra nodes. Focus on: + - Consistency with node patterns (INPUT_TYPES, RETURN_TYPES, FUNCTION, CATEGORY) + - No breaking changes to existing node interfaces + - path: "comfy_execution/**" + instructions: | + Execution engine (graph execution, caching, jobs). Focus on: + - Caching correctness + - Concurrent execution safety + - Graph validation edge cases + - path: "nodes.py" + instructions: | + Core node definitions (2500+ lines). Focus on: + - Backward compatibility of NODE_CLASS_MAPPINGS + - Consistency of INPUT_TYPES return format + - path: "alembic_db/**" + instructions: | + Database migrations. Focus on: + - Migration safety and rollback support + - Data preservation during schema changes + + auto_review: + enabled: true + auto_incremental_review: true + drafts: false + ignore_title_keywords: + - "WIP" + - "DO NOT REVIEW" + - "DO NOT MERGE" + + finishing_touches: + docstrings: + enabled: false + unit_tests: + enabled: false + + tools: + ruff: + enabled: false + pylint: + enabled: false + flake8: + enabled: false + gitleaks: + enabled: true + shellcheck: + enabled: false + markdownlint: + enabled: false + yamllint: + enabled: false + languagetool: + enabled: false + github-checks: + enabled: true + timeout_ms: 90000 + ast-grep: + essential_rules: true + +chat: + auto_reply: true + +knowledge_base: + opt_out: false + learnings: + scope: "auto" diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 6556677e0..95cc48f88 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -16,7 +16,7 @@ body: ## Very Important - Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored. + Please make sure that you post ALL your ComfyUI logs in the bug report **even if there is no crash**. Just paste everything. The startup log (everything before "To see the GUI go to: ...") contains critical information to developers trying to help. For a performance issue or crash, paste everything from "got prompt" to the end, including the crash. More is better - always. A bug report without logs will likely be ignored. 
   - type: checkboxes
     id: custom-nodes-test
     attributes:
diff --git a/.github/scripts/check-ai-co-authors.sh b/.github/scripts/check-ai-co-authors.sh
new file mode 100755
index 000000000..842b1f2d8
--- /dev/null
+++ b/.github/scripts/check-ai-co-authors.sh
@@ -0,0 +1,103 @@
+#!/usr/bin/env bash
+# Checks pull request commits for AI agent Co-authored-by trailers.
+# Exits non-zero when any are found and prints fix instructions.
+set -euo pipefail
+
+base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
+head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
+
+# Known AI coding-agent trailer patterns (case-insensitive).
+# Each entry is an extended-regex fragment matched against Co-authored-by lines.
+AGENT_PATTERNS=(
+    # Anthropic — Claude Code / Amp
+    'noreply@anthropic\.com'
+    # Cursor
+    'cursoragent@cursor\.com'
+    # GitHub Copilot
+    'copilot-swe-agent\[bot\]'
+    'copilot@github\.com'
+    # OpenAI Codex
+    'noreply@openai\.com'
+    'codex@openai\.com'
+    # Aider
+    'aider@aider\.chat'
+    # Google — Gemini / Jules
+    'gemini@google\.com'
+    'jules@google\.com'
+    # Windsurf / Codeium
+    '@codeium\.com'
+    # Devin
+    'devin-ai-integration\[bot\]'
+    'devin@cognition\.ai'
+    'devin@cognition-labs\.com'
+    # Amazon Q Developer
+    'amazon-q-developer'
+    '@amazon\.com.*[Qq].[Dd]eveloper'
+    # Cline
+    'cline-bot'
+    'cline@cline\.ai'
+    # Continue
+    'continue-agent'
+    'continue@continue\.dev'
+    # Sourcegraph
+    'noreply@sourcegraph\.com'
+    # Generic catch-alls for common agent name patterns
+    'Co-authored-by:.*\b[Cc]laude\b'
+    'Co-authored-by:.*\b[Cc]opilot\b'
+    'Co-authored-by:.*\b[Cc]ursor\b'
+    'Co-authored-by:.*\b[Cc]odex\b'
+    'Co-authored-by:.*\b[Gg]emini\b'
+    'Co-authored-by:.*\b[Aa]ider\b'
+    'Co-authored-by:.*\b[Dd]evin\b'
+    'Co-authored-by:.*\b[Ww]indsurf\b'
+    'Co-authored-by:.*\b[Cc]line\b'
+    'Co-authored-by:.*\b[Aa]mazon Q\b'
+    'Co-authored-by:.*\b[Jj]ules\b'
+    'Co-authored-by:.*\bOpenCode\b'
+)
+
+# Build a single alternation regex from all patterns.
+regex=""
+for pattern in "${AGENT_PATTERNS[@]}"; do
+    if [[ -n "$regex" ]]; then
+        regex="${regex}|${pattern}"
+    else
+        regex="$pattern"
+    fi
+done
+
+# Collect Co-authored-by lines from every commit in the PR range.
+violations=""
+while IFS= read -r sha; do
+    message="$(git log -1 --format='%B' "$sha")"
+    matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
+    if [[ -z "$matched_lines" ]]; then
+        continue
+    fi
+
+    while IFS= read -r line; do
+        if echo "$line" | grep -iqE "$regex"; then
+            short="$(git log -1 --format='%h' "$sha")"
+            violations="${violations} ${short}: ${line}"$'\n'
+        fi
+    done <<< "$matched_lines"
+done < <(git rev-list "${base_sha}..${head_sha}")
+
+if [[ -n "$violations" ]]; then
+    echo "::error::AI agent Co-authored-by trailers detected in PR commits."
+    echo ""
+    echo "The following commits contain Co-authored-by trailers from AI coding agents:"
+    echo ""
+    echo "$violations"
+    echo "These trailers should be removed before merging."
+    echo ""
+    echo "To fix, rewrite the commit messages with:"
+    echo "  git rebase -i ${base_sha}"
+    echo ""
+    echo "and remove the Co-authored-by lines, then force-push your branch."
+    echo ""
+    echo "If you believe this is a false positive, please open an issue."
+    exit 1
+fi
+
+echo "No AI agent Co-authored-by trailers found."
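The script can also be run locally to catch trailers before a PR is opened; a minimal sketch, assuming the PR will target origin/master (adjust the base ref as needed):

    base="$(git merge-base origin/master HEAD)"
    bash .github/scripts/check-ai-co-authors.sh "$base" HEAD
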
diff --git a/.github/workflows/check-ai-co-authors.yml b/.github/workflows/check-ai-co-authors.yml new file mode 100644 index 000000000..2ad9ac972 --- /dev/null +++ b/.github/workflows/check-ai-co-authors.yml @@ -0,0 +1,19 @@ +name: Check AI Co-Authors + +on: + pull_request: + branches: ['*'] + +jobs: + check-ai-co-authors: + name: Check for AI agent co-author trailers + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Check commits for AI co-author trailers + run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" diff --git a/.github/workflows/release-stable-all.yml b/.github/workflows/release-stable-all.yml index d72ece2ce..8f07a7b1c 100644 --- a/.github/workflows/release-stable-all.yml +++ b/.github/workflows/release-stable-all.yml @@ -20,7 +20,7 @@ jobs: git_tag: ${{ inputs.git_tag }} cache_tag: "cu130" python_minor: "13" - python_patch: "9" + python_patch: "11" rel_name: "nvidia" rel_extra_name: "" test_release: true @@ -65,11 +65,11 @@ jobs: contents: "write" packages: "write" pull-requests: "read" - name: "Release AMD ROCm 7.1.1" + name: "Release AMD ROCm 7.2" uses: ./.github/workflows/stable-release.yml with: git_tag: ${{ inputs.git_tag }} - cache_tag: "rocm711" + cache_tag: "rocm72" python_minor: "12" python_patch: "10" rel_name: "amd" diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml index 6fceb7560..737e4c488 100644 --- a/.github/workflows/release-webhook.yml +++ b/.github/workflows/release-webhook.yml @@ -7,6 +7,8 @@ on: jobs: send-webhook: runs-on: ubuntu-latest + env: + DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }} steps: - name: Send release webhook env: @@ -106,3 +108,37 @@ jobs: --fail --silent --show-error echo "✅ Release webhook sent successfully" + + - name: Send repository dispatch to desktop + env: + DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }} + RELEASE_TAG: ${{ github.event.release.tag_name }} + RELEASE_URL: ${{ github.event.release.html_url }} + run: | + set -euo pipefail + + if [ -z "${DISPATCH_TOKEN:-}" ]; then + echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set." 
+ exit 1 + fi + + PAYLOAD="$(jq -n \ + --arg release_tag "$RELEASE_TAG" \ + --arg release_url "$RELEASE_URL" \ + '{ + event_type: "comfyui_release_published", + client_payload: { + release_tag: $release_tag, + release_url: $release_url + } + }')" + + curl -fsSL \ + -X POST \ + -H "Accept: application/vnd.github+json" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${DISPATCH_TOKEN}" \ + https://api.github.com/repos/Comfy-Org/desktop/dispatches \ + -d "$PAYLOAD" + + echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop" diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml index 28484a9d1..f501b7b31 100644 --- a/.github/workflows/stable-release.yml +++ b/.github/workflows/stable-release.yml @@ -117,7 +117,7 @@ jobs: ./python.exe get-pip.py ./python.exe -s -m pip install ../${{ inputs.cache_tag }}_python_deps/* - grep comfyui ../ComfyUI/requirements.txt > ./requirements_comfyui.txt + grep comfy ../ComfyUI/requirements.txt > ./requirements_comfyui.txt ./python.exe -s -m pip install -r requirements_comfyui.txt rm requirements_comfyui.txt diff --git a/.github/workflows/test-build.yml b/.github/workflows/test-build.yml index 419873ad8..9160242e9 100644 --- a/.github/workflows/test-build.yml +++ b/.github/workflows/test-build.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} diff --git a/.github/workflows/test-launch.yml b/.github/workflows/test-launch.yml index fd70aff23..581c0474b 100644 --- a/.github/workflows/test-launch.yml +++ b/.github/workflows/test-launch.yml @@ -13,7 +13,7 @@ jobs: - name: Checkout ComfyUI uses: actions/checkout@v4 with: - repository: "comfyanonymous/ComfyUI" + repository: "Comfy-Org/ComfyUI" path: "ComfyUI" - uses: actions/setup-python@v4 with: @@ -32,7 +32,9 @@ jobs: working-directory: ComfyUI - name: Check for unhandled exceptions in server log run: | - if grep -qE "Exception|Error" console_output.log; then + grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': True, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" console_output.log | grep -v "Found comfy_kitchen backend triton: {'available': False, 'disabled': False, 'unavailable_reason': \"ImportError: No module named 'triton'\", 'capabilities': \[\]}" > console_output_filtered.log + cat console_output_filtered.log + if grep -qE "Exception|Error" console_output_filtered.log; then echo "Unhandled exception/error found in server log." 
exit 1 fi diff --git a/.github/workflows/update-ci-container.yml b/.github/workflows/update-ci-container.yml new file mode 100644 index 000000000..f7972e056 --- /dev/null +++ b/.github/workflows/update-ci-container.yml @@ -0,0 +1,59 @@ +name: "CI: Update CI Container" + +on: + release: + types: [published] + workflow_dispatch: + inputs: + version: + description: 'ComfyUI version (e.g., v0.7.0)' + required: true + type: string + +jobs: + update-ci-container: + runs-on: ubuntu-latest + # Skip pre-releases unless manually triggered + if: github.event_name == 'workflow_dispatch' || !github.event.release.prerelease + steps: + - name: Get version + id: version + run: | + if [ "${{ github.event_name }}" = "release" ]; then + VERSION="${{ github.event.release.tag_name }}" + else + VERSION="${{ inputs.version }}" + fi + echo "version=$VERSION" >> $GITHUB_OUTPUT + + - name: Checkout comfyui-ci-container + uses: actions/checkout@v4 + with: + repository: comfy-org/comfyui-ci-container + token: ${{ secrets.CI_CONTAINER_PAT }} + + - name: Check current version + id: current + run: | + CURRENT=$(grep -oP 'ARG COMFYUI_VERSION=\K.*' Dockerfile || echo "unknown") + echo "current_version=$CURRENT" >> $GITHUB_OUTPUT + + - name: Update Dockerfile + run: | + VERSION="${{ steps.version.outputs.version }}" + sed -i "s/^ARG COMFYUI_VERSION=.*/ARG COMFYUI_VERSION=${VERSION}/" Dockerfile + + - name: Create Pull Request + id: create-pr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ secrets.CI_CONTAINER_PAT }} + branch: automation/comfyui-${{ steps.version.outputs.version }} + title: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" + body: | + Updates ComfyUI version from `${{ steps.current.outputs.current_version }}` to `${{ steps.version.outputs.version }}` + + **Triggered by:** ${{ github.event_name == 'release' && format('[Release {0}]({1})', github.event.release.tag_name, github.event.release.html_url) || 'Manual workflow dispatch' }} + + labels: automation + commit-message: "chore: bump ComfyUI to ${{ steps.version.outputs.version }}" diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index f61ee21a2..93e01ac93 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -29,7 +29,7 @@ on: description: 'python patch version' required: true type: string - default: "9" + default: "11" # push: # branches: # - master diff --git a/.gitignore b/.gitignore index 4e8cea71e..2700ad5c2 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ extra_model_paths.yaml /.vs .vscode/ .idea/ -venv/ +venv*/ .venv/ /web/extensions/* !/web/extensions/logging.js.example diff --git a/README.md b/README.md index bae955b1b..62c4f528c 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a ## Get Started +### Local + #### [Desktop Application](https://www.comfy.org/download) - The easiest way to get started. - Available on Windows & macOS. @@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a #### [Manual Install](#manual-install-windows-linux) Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend). -## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) -See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/). 
+### Cloud
+
+#### [Comfy Cloud](https://www.comfy.org/cloud)
+- Our official paid cloud version for those who can't afford local hardware.
+
+## Examples
+See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or the older [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
 
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
@@ -108,7 +115,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Works fully offline: core will never download anything unless you want to.
-- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview).
+- Optional API nodes to use paid models from external providers through the online [Comfy API](https://docs.comfy.org/tutorials/api-nodes/overview). Disable with: `--disable-api-nodes`
 - [Config file](extra_model_paths.yaml.example) to set the search paths for models.
 
 Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
@@ -119,6 +126,9 @@ ComfyUI follows a weekly release cycle targeting Monday but this regularly chang
 
 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
    - Releases a new stable version (e.g., v0.7.0) roughly every week.
+   - Starting from v0.4.0, patch versions will be used for fixes backported onto the current stable release.
+   - Minor versions will be used for releases off the master branch.
+   - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
    - Commits outside of the stable release tags may be very unstable and break many custom nodes.
    - Serves as the foundation for the desktop release
 
@@ -180,14 +190,12 @@ Simply download, extract with [7-Zip](https://7-zip.org) or with the windows exp
 
 If you have trouble extracting it, right click the file -> properties -> unblock
 
-Update your Nvidia drivers if it doesn't start.
+The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
 
 #### Alternative Downloads:
 
 [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
 
-[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
-
 [Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
 
 #### How do I share models between another UI and ComfyUI?
@@ -205,10 +213,12 @@
 
 comfy install
 
 ## Manual Install (Windows, Linux)
 
-Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
+Python 3.14 works, but some custom nodes may have issues. The free-threaded variant works, but some dependencies re-enable the GIL, so it's not fully supported.
 
 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
 
+torch 2.4 and above is supported, but some features and optimizations might only work on newer versions.
We generally recommend using the latest major version of pytorch with the latest cuda version unless it is less than 2 weeks old. + ### Instructions: Git clone this repo. @@ -222,11 +232,11 @@ Put your VAE in: models/vae AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1``` -This is the command to install the nightly with ROCm 7.0 which might have some performance improvements: +This is the command to install the nightly with ROCm 7.2 which might have some performance improvements: -```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1``` +```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.2``` ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. @@ -235,7 +245,7 @@ These have less hardware support than the builds above but they work on windows. RDNA 3 (RX 7000 series): -```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-dgpu/``` +```pip install --pre torch torchvision torchaudio --index-url https://rocm.nightlies.amd.com/v2/gfx110X-all/``` RDNA 3.5 (Strix halo/Ryzen AI Max+ 365): diff --git a/alembic_db/env.py b/alembic_db/env.py index 4d7770679..4ce37c012 100644 --- a/alembic_db/env.py +++ b/alembic_db/env.py @@ -8,7 +8,7 @@ from alembic import context config = context.config -from app.database.models import Base +from app.database.models import Base, NAMING_CONVENTION target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, @@ -51,7 +51,10 @@ def run_migrations_online() -> None: with connectable.connect() as connection: context.configure( - connection=connection, target_metadata=target_metadata + connection=connection, + target_metadata=target_metadata, + render_as_batch=True, + naming_convention=NAMING_CONVENTION, ) with context.begin_transaction(): diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py new file mode 100644 index 000000000..1e10b94dc --- /dev/null +++ b/alembic_db/versions/0001_assets.py @@ -0,0 +1,174 @@ +""" +Initial assets schema +Revision ID: 0001_assets +Revises: None +Create Date: 2025-12-10 00:00:00 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0001_assets" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ASSETS: content identity + op.create_table( + "assets", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("hash", sa.String(length=256), nullable=True), + sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), + sa.Column("mime_type", sa.String(length=255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), + ) + op.create_index("uq_assets_hash", "assets", ["hash"], unique=True) + op.create_index("ix_assets_mime_type", "assets", ["mime_type"]) + + # ASSETS_INFO: user-visible references + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), 
nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + # TAGS: normalized tag vocabulary + op.create_table( + "tags", + sa.Column("name", sa.String(length=512), primary_key=True), + sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"), + sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"), + ) + op.create_index("ix_tags_tag_type", "tags", ["tag_type"]) + + # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + # ASSET_CACHE_STATE: N:1 local cache rows per Asset + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + 
sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) + + # Tags vocabulary + tags_table = sa.table( + "tags", + sa.column("name", sa.String(length=512)), + sa.column("tag_type", sa.String()), + ) + op.bulk_insert( + tags_table, + [ + {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, + + {"name": "configs", "tag_type": "system"}, + {"name": "checkpoints", "tag_type": "system"}, + {"name": "loras", "tag_type": "system"}, + {"name": "vae", "tag_type": "system"}, + {"name": "text_encoders", "tag_type": "system"}, + {"name": "diffusion_models", "tag_type": "system"}, + {"name": "clip_vision", "tag_type": "system"}, + {"name": "style_models", "tag_type": "system"}, + {"name": "embeddings", "tag_type": "system"}, + {"name": "diffusers", "tag_type": "system"}, + {"name": "vae_approx", "tag_type": "system"}, + {"name": "controlnet", "tag_type": "system"}, + {"name": "gligen", "tag_type": "system"}, + {"name": "upscale_models", "tag_type": "system"}, + {"name": "hypernetworks", "tag_type": "system"}, + {"name": "photomaker", "tag_type": "system"}, + {"name": "classifiers", "tag_type": "system"}, + + {"name": "encoder", "tag_type": "system"}, + {"name": "decoder", "tag_type": "system"}, + + {"name": "missing", "tag_type": "system"}, + {"name": "rescan", "tag_type": "system"}, + ], + ) + + +def downgrade() -> None: + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_tags_tag_type", table_name="tags") + op.drop_table("tags") + + op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + op.drop_index("uq_assets_hash", table_name="assets") + op.drop_index("ix_assets_mime_type", table_name="assets") + op.drop_table("assets") diff --git a/alembic_db/versions/0002_merge_to_asset_references.py 
b/alembic_db/versions/0002_merge_to_asset_references.py new file mode 100644 index 000000000..1ac1b980c --- /dev/null +++ b/alembic_db/versions/0002_merge_to_asset_references.py @@ -0,0 +1,267 @@ +""" +Merge AssetInfo and AssetCacheState into unified asset_references table. + +This migration drops old tables and creates the new unified schema. +All existing data is discarded. + +Revision ID: 0002_merge_to_asset_references +Revises: 0001_assets +Create Date: 2025-02-11 +""" + +from alembic import op +import sqlalchemy as sa + +revision = "0002_merge_to_asset_references" +down_revision = "0001_assets" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # Drop old tables (order matters due to FK constraints) + op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta") + op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta") + op.drop_table("asset_info_meta") + + op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags") + op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags") + op.drop_table("asset_info_tags") + + op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state") + op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state") + op.drop_table("asset_cache_state") + + op.drop_index("ix_assets_info_owner_name", table_name="assets_info") + op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") + op.drop_index("ix_assets_info_created_at", table_name="assets_info") + op.drop_index("ix_assets_info_name", table_name="assets_info") + op.drop_index("ix_assets_info_asset_id", table_name="assets_info") + op.drop_index("ix_assets_info_owner_id", table_name="assets_info") + op.drop_table("assets_info") + + # Truncate assets table (cascades handled by dropping dependent tables first) + op.execute("DELETE FROM assets") + + # Create asset_references table + op.create_table( + "asset_references", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column( + "asset_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("file_path", sa.Text(), nullable=True), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column( + "needs_verify", + sa.Boolean(), + nullable=False, + server_default=sa.text("false"), + ), + sa.Column( + "is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false") + ), + sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column( + "preview_id", + sa.String(length=36), + sa.ForeignKey("assets.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True), + sa.CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" + ), + sa.CheckConstraint( + "enrichment_level >= 0 AND enrichment_level <= 2", + name="ck_ar_enrichment_level_range", + ), + ) + 
op.create_index( + "uq_asset_references_file_path", "asset_references", ["file_path"], unique=True + ) + op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"]) + op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"]) + op.create_index("ix_asset_references_name", "asset_references", ["name"]) + op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"]) + op.create_index( + "ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"] + ) + op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"]) + op.create_index( + "ix_asset_references_last_access_time", "asset_references", ["last_access_time"] + ) + op.create_index( + "ix_asset_references_owner_name", "asset_references", ["owner_id", "name"] + ) + op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"]) + + # Create asset_reference_tags table + op.create_table( + "asset_reference_tags", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column( + "tag_name", + sa.String(length=512), + sa.ForeignKey("tags.name", ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + "origin", sa.String(length=32), nullable=False, server_default="manual" + ), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint( + "asset_reference_id", "tag_name", name="pk_asset_reference_tags" + ), + ) + op.create_index( + "ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"] + ) + op.create_index( + "ix_asset_reference_tags_asset_reference_id", + "asset_reference_tags", + ["asset_reference_id"], + ) + + # Create asset_reference_meta table + op.create_table( + "asset_reference_meta", + sa.Column( + "asset_reference_id", + sa.String(length=36), + sa.ForeignKey("asset_references.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint( + "asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta" + ), + ) + op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"]) + op.create_index( + "ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"] + ) + op.create_index( + "ix_asset_reference_meta_key_val_bool", + "asset_reference_meta", + ["key", "val_bool"], + ) + + +def downgrade() -> None: + """Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema. + + NOTE: Data is not recoverable. The upgrade discards all rows from the old + tables and truncates assets. After downgrade the old schema will be empty. + A filesystem rescan will repopulate data once the older code is running. 
+ """ + # Drop new tables (order matters due to FK constraints) + op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta") + op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta") + op.drop_table("asset_reference_meta") + + op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags") + op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags") + op.drop_table("asset_reference_tags") + + op.drop_index("ix_asset_references_deleted_at", table_name="asset_references") + op.drop_index("ix_asset_references_owner_name", table_name="asset_references") + op.drop_index("ix_asset_references_last_access_time", table_name="asset_references") + op.drop_index("ix_asset_references_created_at", table_name="asset_references") + op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references") + op.drop_index("ix_asset_references_is_missing", table_name="asset_references") + op.drop_index("ix_asset_references_name", table_name="asset_references") + op.drop_index("ix_asset_references_owner_id", table_name="asset_references") + op.drop_index("ix_asset_references_asset_id", table_name="asset_references") + op.drop_index("uq_asset_references_file_path", table_name="asset_references") + op.drop_table("asset_references") + + # Truncate assets (upgrade deleted all rows; downgrade starts fresh too) + op.execute("DELETE FROM assets") + + # Recreate old tables from 0001_assets schema + op.create_table( + "assets_info", + sa.Column("id", sa.String(length=36), primary_key=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), + sa.Column("name", sa.String(length=512), nullable=False), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True), + sa.Column("user_metadata", sa.JSON(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), + sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + ) + op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) + op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"]) + op.create_index("ix_assets_info_name", "assets_info", ["name"]) + op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"]) + op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"]) + op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"]) + + op.create_table( + "asset_cache_state", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False), + sa.Column("file_path", sa.Text(), nullable=False), + sa.Column("mtime_ns", sa.BigInteger(), nullable=True), + sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")), + sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + sa.UniqueConstraint("file_path", 
name="uq_asset_cache_state_file_path"), + ) + op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"]) + op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"]) + + op.create_table( + "asset_info_tags", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), + sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), + sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), + ) + op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) + op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"]) + + op.create_table( + "asset_info_meta", + sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), + sa.Column("key", sa.String(length=256), nullable=False), + sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"), + sa.Column("val_str", sa.String(length=2048), nullable=True), + sa.Column("val_num", sa.Numeric(38, 10), nullable=True), + sa.Column("val_bool", sa.Boolean(), nullable=True), + sa.Column("val_json", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"), + ) + op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"]) + op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"]) + op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"]) + op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"]) diff --git a/alembic_db/versions/0003_add_metadata_job_id.py b/alembic_db/versions/0003_add_metadata_job_id.py new file mode 100644 index 000000000..2a14ee924 --- /dev/null +++ b/alembic_db/versions/0003_add_metadata_job_id.py @@ -0,0 +1,98 @@ +""" +Add system_metadata and job_id columns to asset_references. +Change preview_id FK from assets.id to asset_references.id. + +Revision ID: 0003_add_metadata_job_id +Revises: 0002_merge_to_asset_references +Create Date: 2026-03-09 +""" + +from alembic import op +import sqlalchemy as sa + +from app.database.models import NAMING_CONVENTION + +revision = "0003_add_metadata_job_id" +down_revision = "0002_merge_to_asset_references" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.add_column( + sa.Column("system_metadata", sa.JSON(), nullable=True) + ) + batch_op.add_column( + sa.Column("job_id", sa.String(length=36), nullable=True) + ) + + # Change preview_id FK from assets.id to asset_references.id (self-ref). + # Existing values are asset-content IDs that won't match reference IDs, + # so null them out first. 
+ op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL") + with op.batch_alter_table( + "asset_references", naming_convention=NAMING_CONVENTION + ) as batch_op: + batch_op.drop_constraint( + "fk_asset_references_preview_id_assets", type_="foreignkey" + ) + batch_op.create_foreign_key( + "fk_asset_references_preview_id_asset_references", + "asset_references", + ["preview_id"], + ["id"], + ondelete="SET NULL", + ) + batch_op.create_index( + "ix_asset_references_preview_id", ["preview_id"] + ) + + # Purge any all-null meta rows before adding the constraint + op.execute( + "DELETE FROM asset_reference_meta" + " WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL" + ) + with op.batch_alter_table("asset_reference_meta") as batch_op: + batch_op.create_check_constraint( + "ck_asset_reference_meta_has_value", + "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL", + ) + + +def downgrade() -> None: + # SQLite doesn't reflect CHECK constraints, so we must declare it + # explicitly via table_args for the batch recreate to find it. + # Use the fully-rendered constraint name to avoid the naming convention + # doubling the prefix. + with op.batch_alter_table( + "asset_reference_meta", + table_args=[ + sa.CheckConstraint( + "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL", + name="ck_asset_reference_meta_has_value", + ), + ], + ) as batch_op: + batch_op.drop_constraint( + "ck_asset_reference_meta_has_value", type_="check" + ) + + with op.batch_alter_table( + "asset_references", naming_convention=NAMING_CONVENTION + ) as batch_op: + batch_op.drop_index("ix_asset_references_preview_id") + batch_op.drop_constraint( + "fk_asset_references_preview_id_asset_references", type_="foreignkey" + ) + batch_op.create_foreign_key( + "fk_asset_references_preview_id_assets", + "assets", + ["preview_id"], + ["id"], + ondelete="SET NULL", + ) + + with op.batch_alter_table("asset_references") as batch_op: + batch_op.drop_column("job_id") + batch_op.drop_column("system_metadata") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py new file mode 100644 index 000000000..68126b6a5 --- /dev/null +++ b/app/assets/api/routes.py @@ -0,0 +1,804 @@ +import asyncio +import functools +import json +import logging +import os +import urllib.parse +import uuid +from typing import Any + +from aiohttp import web +from pydantic import ValidationError + +import folder_paths +from app import user_manager +from app.assets.api import schemas_in, schemas_out +from app.assets.services import schemas +from app.assets.api.schemas_in import ( + AssetValidationError, + UploadError, +) +from app.assets.helpers import validate_blake3_hash +from app.assets.api.upload import ( + delete_temp_file_if_exists, + parse_multipart_upload, +) +from app.assets.seeder import ScanInProgressError, asset_seeder +from app.assets.services import ( + DependencyMissingError, + HashMismatchError, + apply_tags, + asset_exists, + create_from_hash, + delete_asset_reference, + get_asset_detail, + list_assets_page, + list_tags, + remove_tags, + resolve_asset_for_download, + update_asset_metadata, + upload_from_temp_path, +) +from app.assets.services.tagging import list_tag_histogram + +ROUTES = web.RouteTableDef() +USER_MANAGER: user_manager.UserManager | None = None +_ASSETS_ENABLED = False + + +def _require_assets_feature_enabled(handler): + @functools.wraps(handler) + async def wrapper(request: web.Request) -> 
web.Response: + if not _ASSETS_ENABLED: + return _build_error_response( + 503, + "SERVICE_DISABLED", + "Assets system is disabled. Start the server with --enable-assets to use this feature.", + ) + return await handler(request) + + return wrapper + + +# UUID regex (canonical hyphenated form, case-insensitive) +UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + +def get_query_dict(request: web.Request) -> dict[str, Any]: + """Gets a dictionary of query parameters from the request. + + request.query is a MultiMapping[str], needs to be converted to a dict + to be validated by Pydantic. + """ + query_dict = { + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) + for key in request.query.keys() + } + return query_dict + + +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, +# do not rely on the code in /app/assets remaining the same. + + +def register_assets_routes( + app: web.Application, + user_manager_instance: user_manager.UserManager | None = None, +) -> None: + global USER_MANAGER, _ASSETS_ENABLED + if user_manager_instance is not None: + USER_MANAGER = user_manager_instance + _ASSETS_ENABLED = True + app.add_routes(ROUTES) + + +def disable_assets_routes() -> None: + """Disable asset routes at runtime (e.g. after DB init failure).""" + global _ASSETS_ENABLED + _ASSETS_ENABLED = False + + +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) + + +def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: + errors = json.loads(ve.json()) + return _build_error_response(400, code, "Validation failed.", {"errors": errors}) + + +def _validate_sort_field(requested: str | None) -> str: + if not requested: + return "created_at" + v = requested.lower() + if v in {"name", "created_at", "updated_at", "size", "last_access_time"}: + return v + return "created_at" + + +def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None: + """Build a /api/view preview URL from asset tags and user_metadata filename.""" + if not user_metadata: + return None + filename = user_metadata.get("filename") + if not filename: + return None + + if "input" in tags: + view_type = "input" + elif "output" in tags: + view_type = "output" + else: + return None + + subfolder = "" + if "/" in filename: + subfolder, filename = filename.rsplit("/", 1) + + encoded_filename = urllib.parse.quote(filename, safe="") + url = f"/api/view?type={view_type}&filename={encoded_filename}" + if subfolder: + url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}" + return url + + +def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset: + """Build an Asset response from a service result.""" + if result.ref.preview_id: + preview_detail = get_asset_detail(result.ref.preview_id) + if preview_detail: + preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata) + else: + preview_url = None + else: + preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) + return schemas_out.Asset( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else 
None,
+        mime_type=result.asset.mime_type if result.asset else None,
+        tags=result.tags,
+        preview_url=preview_url,
+        preview_id=result.ref.preview_id,
+        user_metadata=result.ref.user_metadata or {},
+        metadata=result.ref.system_metadata,
+        job_id=result.ref.job_id,
+        prompt_id=result.ref.job_id,  # deprecated: mirrors job_id for cloud compat
+        created_at=result.ref.created_at,
+        updated_at=result.ref.updated_at,
+        last_access_time=result.ref.last_access_time,
+    )
+
+
+@ROUTES.head("/api/assets/hash/{hash}")
+@_require_assets_feature_enabled
+async def head_asset_by_hash(request: web.Request) -> web.Response:
+    hash_str = request.match_info.get("hash", "").strip().lower()
+    try:
+        hash_str = validate_blake3_hash(hash_str)
+    except ValueError:
+        return _build_error_response(
+            400, "INVALID_HASH", "hash must be like 'blake3:<64 hex chars>'"
+        )
+    exists = asset_exists(hash_str)
+    return web.Response(status=200 if exists else 404)
+
+
+@ROUTES.get("/api/assets")
+@_require_assets_feature_enabled
+async def list_assets_route(request: web.Request) -> web.Response:
+    """
+    GET request to list assets.
+    """
+    query_dict = get_query_dict(request)
+    try:
+        q = schemas_in.ListAssetsQuery.model_validate(query_dict)
+    except ValidationError as ve:
+        return _build_validation_error_response("INVALID_QUERY", ve)
+
+    sort = _validate_sort_field(q.sort)
+    order_candidate = (q.order or "desc").lower()
+    order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
+
+    result = list_assets_page(
+        owner_id=USER_MANAGER.get_request_user_id(request),
+        include_tags=q.include_tags,
+        exclude_tags=q.exclude_tags,
+        name_contains=q.name_contains,
+        metadata_filter=q.metadata_filter,
+        limit=q.limit,
+        offset=q.offset,
+        sort=sort,
+        order=order,
+    )
+
+    summaries = [_build_asset_response(item) for item in result.items]
+
+    payload = schemas_out.AssetsList(
+        assets=summaries,
+        total=result.total,
+        has_more=(q.offset + len(summaries)) < result.total,
+    )
+    return web.json_response(payload.model_dump(mode="json", exclude_none=True))
+
+
+@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
+@_require_assets_feature_enabled
+async def get_asset_route(request: web.Request) -> web.Response:
+    """
+    GET request to get an asset's info as JSON.
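+
+    Returns a 404 ASSET_NOT_FOUND error when the reference does not exist
+    or is not visible to the requesting user.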
+ """ + reference_id = str(uuid.UUID(request.match_info["id"])) + try: + result = get_asset_detail( + reference_id=reference_id, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + if not result: + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetReference {reference_id} not found", + {"id": reference_id}, + ) + + payload = _build_asset_response(result) + except ValueError as e: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} + ) + except Exception: + logging.exception( + "get_asset failed for reference_id=%s, owner_id=%s", + reference_id, + USER_MANAGER.get_request_user_id(request), + ) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +@_require_assets_feature_enabled +async def download_asset_content(request: web.Request) -> web.Response: + disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + result = resolve_asset_for_download( + reference_id=str(uuid.UUID(request.match_info["id"])), + owner_id=USER_MANAGER.get_request_user_id(request), + ) + abs_path = result.abs_path + content_type = result.content_type + filename = result.download_name + except ValueError as ve: + return _build_error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) + + _DANGEROUS_MIME_TYPES = { + "text/html", "text/html-sandboxed", "application/xhtml+xml", + "text/javascript", "text/css", + } + if content_type in _DANGEROUS_MIME_TYPES: + content_type = "application/octet-stream" + + safe_name = (filename or "").replace("\r", "").replace("\n", "") + encoded = urllib.parse.quote(safe_name) + cd = f"{disposition}; filename*=UTF-8''{encoded}" + + file_size = os.path.getsize(abs_path) + size_mb = file_size / (1024 * 1024) + logging.info( + "download_asset_content: path=%s, size=%d bytes (%.2f MB), type=%s, name=%s", + abs_path, + file_size, + size_mb, + content_type, + filename, + ) + + async def stream_file_chunks(): + chunk_size = 64 * 1024 + with open(abs_path, "rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + yield chunk + + return web.Response( + body=stream_file_chunks(), + content_type=content_type, + headers={ + "Content-Disposition": cd, + "Content-Length": str(file_size), + "X-Content-Type-Options": "nosniff", + }, + ) + + +@ROUTES.post("/api/assets/from-hash") +@_require_assets_feature_enabled +async def create_asset_from_hash_route(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _build_validation_error_response("INVALID_BODY", ve) + except Exception: + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." 
+ ) + + # Derive name from hash if not provided + name = body.name + if name is None: + name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash + + result = create_from_hash( + hash_str=body.hash, + name=name, + tags=body.tags, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + mime_type=body.mime_type, + preview_id=body.preview_id, + ) + if result is None: + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) + + asset = _build_asset_response(result) + payload_out = schemas_out.AssetCreated( + **asset.model_dump(), + created_new=result.created_new, + ) + return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201) + + +@ROUTES.post("/api/assets") +@_require_assets_feature_enabled +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + try: + parsed = await parse_multipart_upload(request, check_hash_exists=asset_exists) + except UploadError as e: + return _build_error_response(e.status, e.code, e.message) + + owner_id = USER_MANAGER.get_request_user_id(request) + + try: + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + "mime_type": parsed.provided_mime_type, + "preview_id": parsed.provided_preview_id, + } + ) + except ValidationError as ve: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) + + if spec.tags and spec.tags[0] == "models": + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): + delete_temp_file_if_exists(parsed.tmp_path) + category = spec.tags[1] if len(spec.tags) >= 2 else "" + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) + + try: + # Fast path: hash exists, create AssetReference without writing anything + if spec.hash and parsed.provided_hash_exists is True: + result = create_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + mime_type=spec.mime_type, + preview_id=spec.preview_id, + ) + if result is None: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) + delete_temp_file_if_exists(parsed.tmp_path) + else: + # Otherwise, we must have a temp file path to ingest + if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): + return _build_error_response( + 400, + "MISSING_INPUT", + "Provided hash not found and no file uploaded.", + ) + + result = upload_from_temp_path( + temp_path=parsed.tmp_path, + name=spec.name, + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + client_filename=parsed.file_client_name, + owner_id=owner_id, + expected_hash=spec.hash, + mime_type=spec.mime_type, + preview_id=spec.preview_id, + ) + except AssetValidationError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, e.code, str(e)) + except ValueError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "BAD_REQUEST", str(e)) + except HashMismatchError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(400, "HASH_MISMATCH", str(e)) + except 
DependencyMissingError as e: + delete_temp_file_if_exists(parsed.tmp_path) + return _build_error_response(503, "DEPENDENCY_MISSING", e.message) + except Exception: + delete_temp_file_if_exists(parsed.tmp_path) + logging.exception("upload_asset failed for owner_id=%s", owner_id) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + asset = _build_asset_response(result) + payload_out = schemas_out.AssetCreated( + **asset.model_dump(), + created_new=result.created_new, + ) + status = 201 if result.created_new else 200 + return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status) + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +@_require_assets_feature_enabled +async def update_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _build_validation_error_response("INVALID_BODY", ve) + except Exception: + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) + + try: + result = update_asset_metadata( + reference_id=reference_id, + name=body.name, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + preview_id=body.preview_id, + ) + payload = _build_asset_response(result) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) + except Exception: + logging.exception( + "update_asset failed for reference_id=%s, owner_id=%s", + reference_id, + USER_MANAGER.get_request_user_id(request), + ) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +@_require_assets_feature_enabled +async def delete_asset_route(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + delete_content_param = request.query.get("delete_content") + delete_content = ( + False + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) + + try: + deleted = delete_asset_reference( + reference_id=reference_id, + owner_id=USER_MANAGER.get_request_user_id(request), + delete_content_if_orphan=delete_content, + ) + except Exception: + logging.exception( + "delete_asset_reference failed for reference_id=%s, owner_id=%s", + reference_id, + USER_MANAGER.get_request_user_id(request), + ) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetReference {reference_id} not found." + ) + return web.Response(status=204) + + +@ROUTES.get("/api/tags") +@_require_assets_feature_enabled +async def get_tags(request: web.Request) -> web.Response: + """ + GET request to list all tags based on query parameters. 
+ """ + query_map = dict(request.rel_url.query) + + try: + query = schemas_in.TagsListQuery.model_validate(query_map) + except ValidationError as e: + return _build_error_response( + 400, + "INVALID_QUERY", + "Invalid query parameters", + {"errors": json.loads(e.json())}, + ) + + rows, total = list_tags( + prefix=query.prefix, + limit=query.limit, + offset=query.offset, + order=query.order, + include_zero=query.include_zero, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) + return web.json_response(payload.model_dump(mode="json", exclude_none=True)) + + +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled +async def add_asset_tags(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + try: + json_payload = await request.json() + data = schemas_in.TagsAdd.model_validate(json_payload) + except ValidationError as ve: + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) + except Exception: + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) + + try: + result = apply_tags( + reference_id=reference_id, + tags=data.tags, + origin="manual", + owner_id=USER_MANAGER.get_request_user_id(request), + ) + payload = schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) + except Exception: + logging.exception( + "add_tags_to_asset failed for reference_id=%s, owner_id=%s", + reference_id, + USER_MANAGER.get_request_user_id(request), + ) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +@_require_assets_feature_enabled +async def delete_asset_tags(request: web.Request) -> web.Response: + reference_id = str(uuid.UUID(request.match_info["id"])) + try: + json_payload = await request.json() + data = schemas_in.TagsRemove.model_validate(json_payload) + except ValidationError as ve: + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) + except Exception: + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." 
+ ) + + try: + result = remove_tags( + reference_id=reference_id, + tags=data.tags, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + payload = schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) + except PermissionError as pe: + return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) + except ValueError as ve: + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id} + ) + except Exception: + logging.exception( + "remove_tags_from_asset failed for reference_id=%s, owner_id=%s", + reference_id, + USER_MANAGER.get_request_user_id(request), + ) + return _build_error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.get("/api/assets/tags/refine") +@_require_assets_feature_enabled +async def get_tags_refine(request: web.Request) -> web.Response: + """GET request to get tag histogram for filtered assets.""" + query_dict = get_query_dict(request) + try: + q = schemas_in.TagsRefineQuery.model_validate(query_dict) + except ValidationError as ve: + return _build_validation_error_response("INVALID_QUERY", ve) + + tag_counts = list_tag_histogram( + owner_id=USER_MANAGER.get_request_user_id(request), + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + ) + payload = schemas_out.TagHistogram(tag_counts=tag_counts) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.post("/api/assets/seed") +@_require_assets_feature_enabled +async def seed_assets(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output). 
+ + Query params: + wait: If "true", block until scan completes (synchronous behavior for tests) + + Returns: + 202 Accepted if scan started + 409 Conflict if scan already running + 200 OK with final stats if wait=true + """ + try: + payload = await request.json() + roots = payload.get("roots", ["models", "input", "output"]) + except Exception: + roots = ["models", "input", "output"] + + valid_roots = tuple(r for r in roots if r in ("models", "input", "output")) + if not valid_roots: + return _build_error_response(400, "INVALID_BODY", "No valid roots specified") + + wait_param = request.query.get("wait", "").lower() + should_wait = wait_param in ("true", "1", "yes") + + started = asset_seeder.start(roots=valid_roots) + if not started: + return web.json_response({"status": "already_running"}, status=409) + + if should_wait: + await asyncio.to_thread(asset_seeder.wait) + status = asset_seeder.get_status() + return web.json_response( + { + "status": "completed", + "progress": { + "scanned": status.progress.scanned if status.progress else 0, + "total": status.progress.total if status.progress else 0, + "created": status.progress.created if status.progress else 0, + "skipped": status.progress.skipped if status.progress else 0, + }, + "errors": status.errors, + }, + status=200, + ) + + return web.json_response({"status": "started"}, status=202) + + +@ROUTES.get("/api/assets/seed/status") +@_require_assets_feature_enabled +async def get_seed_status(request: web.Request) -> web.Response: + """Get current scan status and progress.""" + status = asset_seeder.get_status() + return web.json_response( + { + "state": status.state.value, + "progress": { + "scanned": status.progress.scanned, + "total": status.progress.total, + "created": status.progress.created, + "skipped": status.progress.skipped, + } + if status.progress + else None, + "errors": status.errors, + }, + status=200, + ) + + +@ROUTES.post("/api/assets/seed/cancel") +@_require_assets_feature_enabled +async def cancel_seed(request: web.Request) -> web.Response: + """Request cancellation of in-progress scan.""" + cancelled = asset_seeder.cancel() + if cancelled: + return web.json_response({"status": "cancelling"}, status=200) + return web.json_response({"status": "idle"}, status=200) + + +@ROUTES.post("/api/assets/prune") +@_require_assets_feature_enabled +async def mark_missing_assets(request: web.Request) -> web.Response: + """Mark assets as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and metadata + are preserved, but references are flagged as missing. They can be + restored if the file reappears in a future scan. 
+ + Returns: + 200 OK with count of marked assets + 409 Conflict if a scan is currently running + """ + try: + marked = asset_seeder.mark_missing_outside_prefixes() + except ScanInProgressError: + return web.json_response( + {"status": "scan_running", "marked": 0}, + status=409, + ) + return web.json_response({"status": "completed", "marked": marked}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py new file mode 100644 index 000000000..186a6ae1e --- /dev/null +++ b/app/assets/api/schemas_in.py @@ -0,0 +1,343 @@ +import json +from dataclasses import dataclass +from typing import Any, Literal + +from app.assets.helpers import validate_blake3_hash +from pydantic import ( + BaseModel, + ConfigDict, + Field, + conint, + field_validator, + model_validator, +) + + +class UploadError(Exception): + """Error during upload parsing with HTTP status and code.""" + + def __init__(self, status: int, code: str, message: str): + super().__init__(message) + self.status = status + self.code = code + self.message = message + + +class AssetValidationError(Exception): + """Validation error in asset processing (invalid tags, metadata, etc.).""" + + def __init__(self, code: str, message: str): + super().__init__(message) + self.code = code + self.message = message + + +@dataclass +class ParsedUpload: + """Result of parsing a multipart upload request.""" + + file_present: bool + file_written: int + file_client_name: str | None + tmp_path: str | None + tags_raw: list[str] + provided_name: str | None + user_metadata_raw: str | None + provided_hash: str | None + provided_hash_exists: bool | None + provided_mime_type: str | None = None + provided_preview_id: str | None = None + + +class ListAssetsQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + + # Accept either a JSON string (query param) or a dict + metadata_filter: dict[str, Any] | None = None + + limit: conint(ge=1, le=500) = 20 + offset: conint(ge=0) = 0 + + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) + order: Literal["asc", "desc"] = "desc" + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + # Accept "a,b,c" or ["a","b"] (we are liberal in what we accept) + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class UpdateAssetBody(BaseModel): + name: str | None = None + user_metadata: dict[str, Any] | None = None + preview_id: str | None = None # references an asset_reference id, not an asset id + + @model_validator(mode="after") + def _validate_at_least_one_field(self): + if all( + v is None + for v in (self.name, self.user_metadata, self.preview_id) + ): + raise ValueError( + "Provide at least one of: name, user_metadata, 
preview_id." + ) + return self + + +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + mime_type: str | None = None + preview_id: str | None = None # references an asset_reference id, not an asset id + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + return validate_blake3_hash(v or "") + + @field_validator("tags", mode="before") + @classmethod + def _normalize_tags_field(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + +class TagsRefineQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + metadata_filter: dict[str, Any] | None = None + limit: conint(ge=1, le=1000) = 100 + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + +class TagsListQuery(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + prefix: str | None = Field(None, min_length=1, max_length=256) + limit: int = Field(100, ge=1, le=1000) + offset: int = Field(0, ge=0, le=10_000_000) + order: Literal["count_desc", "name_asc"] = "count_desc" + include_zero: bool = True + + @field_validator("prefix") + @classmethod + def normalize_prefix(cls, v: str | None) -> str | None: + if v is None: + return v + v = v.strip() + return v.lower() or None + + +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) + + @field_validator("tags") + @classmethod + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. 
+
+    - tags: optional list; if provided, first is root ('models'|'input'|'output');
+      if root == 'models', second must be a valid category
+    - name: display name
+    - user_metadata: arbitrary JSON object (optional)
+    - hash: optional canonical 'blake3:<hex>' for validation / fast-path
+    - mime_type: optional MIME type override
+    - preview_id: optional asset_reference ID for preview
+
+    Files are stored using the content hash as filename stem.
+    """
+
+    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
+
+    tags: list[str] = Field(default_factory=list)
+    name: str | None = Field(default=None, max_length=512, description="Display Name")
+    user_metadata: dict[str, Any] = Field(default_factory=dict)
+    hash: str | None = Field(default=None)
+    mime_type: str | None = Field(default=None)
+    preview_id: str | None = Field(default=None)  # references an asset_reference id
+
+    @field_validator("hash", mode="before")
+    @classmethod
+    def _parse_hash(cls, v):
+        if v is None:
+            return None
+        s = str(v).strip()
+        if not s:
+            return None
+        return validate_blake3_hash(s)
+
+    @field_validator("tags", mode="before")
+    @classmethod
+    def _parse_tags(cls, v):
+        """
+        Accepts a list of strings (possibly multiple form fields),
+        where each string can be:
+        - JSON array (e.g., '["models","loras","foo"]')
+        - comma-separated ('models, loras, foo')
+        - single token ('models')
+        Returns a normalized, deduplicated, ordered list.
+        """
+        items: list[str] = []
+        if v is None:
+            return []
+        if isinstance(v, str):
+            v = [v]
+
+        if isinstance(v, list):
+            for item in v:
+                if item is None:
+                    continue
+                s = str(item).strip()
+                if not s:
+                    continue
+                if s.startswith("["):
+                    try:
+                        arr = json.loads(s)
+                        if isinstance(arr, list):
+                            items.extend(str(x) for x in arr)
+                            continue
+                    except Exception:
+                        pass  # fallback to CSV parse below
+                items.extend([p for p in s.split(",") if p.strip()])
+        else:
+            return []
+
+        # normalize + dedupe
+        norm = []
+        seen = set()
+        for t in items:
+            tnorm = str(t).strip().lower()
+            if tnorm and tnorm not in seen:
+                seen.add(tnorm)
+                norm.append(tnorm)
+        return norm
+
+    @field_validator("user_metadata", mode="before")
+    @classmethod
+    def _parse_metadata_json(cls, v):
+        if v is None or isinstance(v, dict):
+            return v or {}
+        if isinstance(v, str):
+            s = v.strip()
+            if not s:
+                return {}
+            try:
+                parsed = json.loads(s)
+            except Exception as e:
+                raise ValueError(f"user_metadata must be JSON: {e}") from e
+            if not isinstance(parsed, dict):
+                raise ValueError("user_metadata must be a JSON object")
+            return parsed
+        return {}
+
+    @model_validator(mode="after")
+    def _validate_order(self):
+        if not self.tags:
+            raise ValueError("at least one tag is required for uploads")
+        root = self.tags[0]
+        if root not in {"models", "input", "output"}:
+            raise ValueError("first tag must be one of: models, input, output")
+        if root == "models":
+            if len(self.tags) < 2:
+                raise ValueError(
+                    "models uploads require a category tag as the second tag"
+                )
+        return self
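
To make the accepted tag encodings concrete, a quick illustrative check (editorial sketch, not part of the patch) of how `_parse_tags` normalizes mixed inputs:

    from app.assets.api.schemas_in import UploadAssetSpec

    spec = UploadAssetSpec.model_validate(
        {"tags": ['["models","loras"]', "Foo, BAR"], "name": "my-lora"}
    )
    # JSON-array and CSV forms merge into one normalized, deduplicated list.
    assert spec.tags == ["models", "loras", "foo", "bar"]
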
diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py
new file mode 100644
index 000000000..d99b1098d
--- /dev/null
+++ b/app/assets/api/schemas_out.py
@@ -0,0 +1,72 @@
+from datetime import datetime
+from typing import Any
+
+from pydantic import BaseModel, ConfigDict, Field, field_serializer
+
+
+class Asset(BaseModel):
+    """API view of an asset.
+
+    Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
+    ``id`` here is the AssetReference id, not the content-addressed Asset id."""
+
+    id: str
+    name: str
+    asset_hash: str | None = None
+    size: int | None = None
+    mime_type: str | None = None
+    tags: list[str] = Field(default_factory=list)
+    preview_url: str | None = None
+    preview_id: str | None = None  # references an asset_reference id, not an asset id
+    user_metadata: dict[str, Any] = Field(default_factory=dict)
+    is_immutable: bool = False
+    metadata: dict[str, Any] | None = None
+    job_id: str | None = None
+    prompt_id: str | None = None  # deprecated: use job_id
+    created_at: datetime
+    updated_at: datetime
+    last_access_time: datetime | None = None
+
+    model_config = ConfigDict(from_attributes=True)
+
+    @field_serializer("created_at", "updated_at", "last_access_time")
+    def _serialize_datetime(self, v: datetime | None, _info):
+        return v.isoformat() if v else None
+
+
+class AssetCreated(Asset):
+    created_new: bool
+
+
+class AssetsList(BaseModel):
+    assets: list[Asset]
+    total: int
+    has_more: bool
+
+
+class TagUsage(BaseModel):
+    name: str
+    count: int
+    type: str
+
+
+class TagsList(BaseModel):
+    tags: list[TagUsage] = Field(default_factory=list)
+    total: int
+    has_more: bool
+
+
+class TagsAdd(BaseModel):
+    model_config = ConfigDict(str_strip_whitespace=True)
+    added: list[str] = Field(default_factory=list)
+    already_present: list[str] = Field(default_factory=list)
+    total_tags: list[str] = Field(default_factory=list)
+
+
+class TagsRemove(BaseModel):
+    model_config = ConfigDict(str_strip_whitespace=True)
+    removed: list[str] = Field(default_factory=list)
+    not_present: list[str] = Field(default_factory=list)
+    total_tags: list[str] = Field(default_factory=list)
+
+
+class TagHistogram(BaseModel):
+    tag_counts: dict[str, int]
diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py
new file mode 100644
index 000000000..13d3d372c
--- /dev/null
+++ b/app/assets/api/upload.py
@@ -0,0 +1,185 @@
+import logging
+import os
+import uuid
+from typing import Callable
+
+from aiohttp import web
+
+import folder_paths
+from app.assets.api.schemas_in import ParsedUpload, UploadError
+from app.assets.helpers import validate_blake3_hash
+
+
+def normalize_and_validate_hash(s: str) -> str:
+    """Validate and normalize a hash string.
+
+    Returns canonical 'blake3:<hex>' or raises UploadError.
+    """
+    try:
+        return validate_blake3_hash(s)
+    except ValueError:
+        raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
+
+
+async def parse_multipart_upload(
+    request: web.Request,
+    check_hash_exists: Callable[[str], bool],
+) -> ParsedUpload:
+    """
+    Parse a multipart/form-data upload request.
+
+    Args:
+        request: The aiohttp request
+        check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
+
+    Returns:
+        ParsedUpload with parsed fields and temp file path
+
+    Raises:
+        UploadError: On validation or I/O errors
+    """
+    if not (request.content_type or "").lower().startswith("multipart/"):
+        raise UploadError(
+            415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
+        )
+
+    reader = await request.multipart()
+
+    file_present = False
+    file_client_name: str | None = None
+    tags_raw: list[str] = []
+    provided_name: str | None = None
+    user_metadata_raw: str | None = None
+    provided_hash: str | None = None
+    provided_hash_exists: bool | None = None
+    provided_mime_type: str | None = None
+    provided_preview_id: str | None = None
+
+    file_written = 0
+    tmp_path: str | None = None
+
+    while True:
+        field = await reader.next()
+        if field is None:
+            break
+
+        fname = getattr(field, "name", "") or ""
+
+        if fname == "hash":
+            try:
+                s = ((await field.text()) or "").strip().lower()
+            except Exception:
+                raise UploadError(
+                    400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
+                )
+
+            if s:
+                provided_hash = normalize_and_validate_hash(s)
+                try:
+                    provided_hash_exists = check_hash_exists(provided_hash)
+                except Exception as e:
+                    logging.exception(
+                        "check_hash_exists failed for hash=%s: %s", provided_hash, e
+                    )
+                    raise UploadError(
+                        500,
+                        "HASH_CHECK_FAILED",
+                        "Backend error while checking asset hash.",
+                    )
+
+        elif fname == "file":
+            file_present = True
+            file_client_name = (field.filename or "").strip()
+
+            if provided_hash and provided_hash_exists is True:
+                # Hash exists - drain file but don't write to disk
+                try:
+                    while True:
+                        chunk = await field.read_chunk(8 * 1024 * 1024)
+                        if not chunk:
+                            break
+                        file_written += len(chunk)
+                except Exception:
+                    raise UploadError(
+                        500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
+                    )
+                continue
+
+            uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
+            unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
+            os.makedirs(unique_dir, exist_ok=True)
+            tmp_path = os.path.join(unique_dir, ".upload.part")
+
+            try:
+                with open(tmp_path, "wb") as f:
+                    while True:
+                        chunk = await field.read_chunk(8 * 1024 * 1024)
+                        if not chunk:
+                            break
+                        f.write(chunk)
+                        file_written += len(chunk)
+            except Exception:
+                delete_temp_file_if_exists(tmp_path)
+                raise UploadError(
+                    500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
+                )
+
+        elif fname == "tags":
+            tags_raw.append((await field.text()) or "")
+        elif fname == "name":
+            provided_name = (await field.text()) or None
+        elif fname == "user_metadata":
+            user_metadata_raw = (await field.text()) or None
+        elif fname == "id":
+            raise UploadError(
+                400,
+                "UNSUPPORTED_FIELD",
+                "Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
+            )
+        elif fname == "mime_type":
+            provided_mime_type = ((await field.text()) or "").strip() or None
+        elif fname == "preview_id":
+            provided_preview_id = ((await field.text()) or "").strip() or None
+
+    if not file_present and not (provided_hash and provided_hash_exists):
+        raise UploadError(
+            400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
+        )
+
+    if (
+        file_present
+        and file_written == 0
+        and not (provided_hash and provided_hash_exists)
+    ):
+        delete_temp_file_if_exists(tmp_path)
+        raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
+
+    return ParsedUpload(
+        file_present=file_present,
+        file_written=file_written,
+        file_client_name=file_client_name,
+        tmp_path=tmp_path,
+        tags_raw=tags_raw,
+        provided_name=provided_name,
+        user_metadata_raw=user_metadata_raw,
+        provided_hash=provided_hash,
+        provided_hash_exists=provided_hash_exists,
+        provided_mime_type=provided_mime_type,
+        provided_preview_id=provided_preview_id,
+    )
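
For reference, a client that matches this parser's field handling might look like the following sketch (editorial example, not part of the patch; the URL is illustrative). Plain form fields are read by the loop above; an optional "hash" field sent before "file" enables the drain-only fast path:

    import aiohttp

    async def upload_asset_file(path: str) -> dict:
        form = aiohttp.FormData()
        form.add_field("tags", "models,loras")
        form.add_field("name", "my-lora.safetensors")
        form.add_field("user_metadata", '{"source": "manual upload"}')
        with open(path, "rb") as f:
            form.add_field("file", f, filename="my-lora.safetensors")
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    "http://127.0.0.1:8188/api/assets", data=form
                ) as resp:
                    return await resp.json()
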
+
+
+def delete_temp_file_if_exists(tmp_path: str | None) -> None:
+    """Safely remove a temp file and its parent directory if empty."""
+    if tmp_path:
+        try:
+            if os.path.exists(tmp_path):
+                os.remove(tmp_path)
+        except OSError as e:
+            logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
+        try:
+            parent = os.path.dirname(tmp_path)
+            if parent and os.path.isdir(parent):
+                os.rmdir(parent)  # only succeeds if empty
+        except OSError:
+            pass
diff --git a/app/assets/database/models.py b/app/assets/database/models.py
new file mode 100644
index 000000000..a3af8a192
--- /dev/null
+++ b/app/assets/database/models.py
@@ -0,0 +1,246 @@
+from __future__ import annotations
+
+import uuid
+from datetime import datetime
+from typing import Any
+
+from sqlalchemy import (
+    JSON,
+    BigInteger,
+    Boolean,
+    CheckConstraint,
+    DateTime,
+    ForeignKey,
+    Index,
+    Integer,
+    Numeric,
+    String,
+    Text,
+)
+from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
+
+from app.assets.helpers import get_utc_now
+from app.database.models import Base
+
+
+class Asset(Base):
+    __tablename__ = "assets"
+
+    id: Mapped[str] = mapped_column(
+        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
+    )
+    hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
+    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
+    mime_type: Mapped[str | None] = mapped_column(String(255))
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=False), nullable=False, default=get_utc_now
+    )
+
+    references: Mapped[list[AssetReference]] = relationship(
+        "AssetReference",
+        back_populates="asset",
+        primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
+        foreign_keys=lambda: [AssetReference.asset_id],
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+    )
+
+    # preview_id on AssetReference is a self-referential FK to asset_references.id
+
+    __table_args__ = (
+        Index("uq_assets_hash", "hash", unique=True),
+        Index("ix_assets_mime_type", "mime_type"),
+        CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
+    )
+
+    def __repr__(self) -> str:
+        return f"<Asset id={self.id!r} hash={self.hash!r} size={self.size_bytes}>"
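
The intent of the split, shown as a sketch (illustrative values; assumes an active SQLAlchemy Session): one content-addressed Asset row backs any number of AssetReference rows, defined next:

    asset = Asset(hash="blake3:<hex digest>", size_bytes=1024)
    ref_a = AssetReference(
        asset=asset,
        name="lora-v1.safetensors",
        file_path="models/loras/lora-v1.safetensors",  # filesystem reference
    )
    ref_b = AssetReference(asset=asset, name="lora v1 (shared)")  # API reference, no file_path
    session.add_all([asset, ref_a, ref_b])
    session.commit()
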
+
+
+class AssetReference(Base):
+    """Unified model combining file cache state and user-facing metadata.
+
+    Each row represents either:
+    - A filesystem reference (file_path is set) with cache state
+    - An API-created reference (file_path is NULL) without cache state
+    """
+
+    __tablename__ = "asset_references"
+
+    id: Mapped[str] = mapped_column(
+        String(36), primary_key=True, default=lambda: str(uuid.uuid4())
+    )
+    asset_id: Mapped[str] = mapped_column(
+        String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
+    )
+
+    # Cache state fields (from former AssetCacheState)
+    file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
+    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
+    needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+    is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+    enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
+
+    # Info fields (from former AssetInfo)
+    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
+    name: Mapped[str] = mapped_column(String(512), nullable=False)
+    preview_id: Mapped[str | None] = mapped_column(
+        String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
+    )
+    user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
+        JSON(none_as_null=True)
+    )
+    system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
+        JSON(none_as_null=True), nullable=True, default=None
+    )
+    job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=False), nullable=False, default=get_utc_now
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=False), nullable=False, default=get_utc_now
+    )
+    last_access_time: Mapped[datetime] = mapped_column(
+        DateTime(timezone=False), nullable=False, default=get_utc_now
+    )
+    deleted_at: Mapped[datetime | None] = mapped_column(
+        DateTime(timezone=False), nullable=True, default=None
+    )
+
+    asset: Mapped[Asset] = relationship(
+        "Asset",
+        back_populates="references",
+        foreign_keys=[asset_id],
+        lazy="selectin",
+    )
+    preview_ref: Mapped[AssetReference | None] = relationship(
+        "AssetReference",
+        foreign_keys=[preview_id],
+        remote_side=lambda: [AssetReference.id],
+    )
+
+    metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
+        back_populates="asset_reference",
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+    )
+
+    tag_links: Mapped[list[AssetReferenceTag]] = relationship(
+        back_populates="asset_reference",
+        cascade="all,delete-orphan",
+        passive_deletes=True,
+        overlaps="tags,asset_references",
+    )
+
+    tags: Mapped[list[Tag]] = relationship(
+        secondary="asset_reference_tags",
+        back_populates="asset_references",
+        lazy="selectin",
+        viewonly=True,
+        overlaps="tag_links,asset_reference_links,asset_references,tag",
+    )
+
+    __table_args__ = (
+        Index("uq_asset_references_file_path", "file_path", unique=True),
+        Index("ix_asset_references_asset_id", "asset_id"),
+        Index("ix_asset_references_owner_id", "owner_id"),
+        Index("ix_asset_references_name", "name"),
+        Index("ix_asset_references_is_missing", "is_missing"),
+        Index("ix_asset_references_enrichment_level", "enrichment_level"),
+        Index("ix_asset_references_created_at", "created_at"),
+        Index("ix_asset_references_last_access_time", "last_access_time"),
+        Index("ix_asset_references_deleted_at", "deleted_at"),
+        Index("ix_asset_references_preview_id", "preview_id"),
+        Index("ix_asset_references_owner_name", "owner_id", "name"),
+        CheckConstraint(
+            "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
+        ),
+        CheckConstraint(
+            "enrichment_level >= 0 AND enrichment_level <= 2",
+            name="ck_ar_enrichment_level_range",
+        ),
+    )
+
+    def __repr__(self) -> str:
+        path_part = f" path={self.file_path!r}" if self.file_path else ""
+        return f"<AssetReference id={self.id!r} asset_id={self.asset_id!r}{path_part}>"
+
+
+class AssetReferenceMeta(Base):
+    __tablename__ = "asset_reference_meta"
+
+    asset_reference_id: Mapped[str] = mapped_column(
+        String(36),
+        ForeignKey("asset_references.id", ondelete="CASCADE"),
+        primary_key=True,
+    )
+    key: Mapped[str] = mapped_column(String(256), primary_key=True)
+    ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
+
+    val_str: Mapped[str | None] = mapped_column(String(2048), nullable=True)
+    val_num: Mapped[float | None] = mapped_column(Numeric(38, 10), nullable=True)
+    val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
+    val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
+
+    asset_reference: Mapped[AssetReference] = relationship(
+        back_populates="metadata_entries"
+    )
+
+    __table_args__ = (
+        Index("ix_asset_reference_meta_key", "key"),
+        Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
+        Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
+        Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
+        CheckConstraint(
+            "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
+            name="has_value",
+        ),
+    )
+
+
+class AssetReferenceTag(Base):
+    __tablename__ = "asset_reference_tags"
+
+    asset_reference_id: Mapped[str] = mapped_column(
+        String(36),
+        ForeignKey("asset_references.id", ondelete="CASCADE"),
+        primary_key=True,
+    )
+    tag_name: Mapped[str] = mapped_column(
+        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
+    )
+    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
+    added_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=False), nullable=False, default=get_utc_now
+    )
+
+    asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
+    tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
+
+    __table_args__ = (
+        Index("ix_asset_reference_tags_tag_name", "tag_name"),
+        Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
+    )
+
+
+class Tag(Base):
+    __tablename__ = "tags"
+
+    name: Mapped[str] = mapped_column(String(512), primary_key=True)
+    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
+
+    asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
+        back_populates="tag",
+        overlaps="asset_references,tags",
+    )
+    asset_references: Mapped[list[AssetReference]] = relationship(
+        secondary="asset_reference_tags",
+        back_populates="tags",
+        viewonly=True,
+        overlaps="asset_reference_links,tag_links,tags,asset_reference",
+    )
+
+    __table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
+
+    def __repr__(self) -> str:
+        return f"<Tag name={self.name!r} type={self.tag_type!r}>"
diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py
new file mode 100644
index 000000000..9949e84e1
--- /dev/null
+++ b/app/assets/database/queries/__init__.py
@@ -0,0 +1,137 @@
+from app.assets.database.queries.asset import (
+    asset_exists_by_hash,
+    bulk_insert_assets,
+    create_stub_asset,
+    get_asset_by_hash,
+    get_existing_asset_ids,
+    reassign_asset_references,
+    update_asset_hash_and_mime,
+    upsert_asset,
+)
+from app.assets.database.queries.asset_reference import (
+    CacheStateRow,
+ UnenrichedReferenceRow, + bulk_insert_references_ignore_conflicts, + bulk_update_enrichment_level, + count_active_siblings, + bulk_update_is_missing, + bulk_update_needs_verify, + convert_metadata_to_rows, + delete_assets_by_ids, + delete_orphaned_seed_asset, + delete_reference_by_id, + delete_references_by_ids, + fetch_reference_and_asset, + fetch_reference_asset_and_tags, + get_or_create_reference, + get_reference_by_file_path, + get_reference_by_id, + get_reference_with_owner_check, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_references_for_prefixes, + get_unenriched_references, + get_unreferenced_unhashed_asset_ids, + insert_reference, + list_all_file_paths_by_asset_id, + list_references_by_asset_id, + list_references_page, + mark_references_missing_outside_prefixes, + rebuild_metadata_projection, + reference_exists, + reference_exists_for_asset_id, + restore_references_by_paths, + set_reference_metadata, + set_reference_preview, + set_reference_system_metadata, + soft_delete_reference_by_id, + update_reference_access_time, + update_reference_name, + update_is_missing_by_asset_id, + update_reference_timestamps, + update_reference_updated_at, + upsert_reference, +) +from app.assets.database.queries.tags import ( + AddTagsResult, + RemoveTagsResult, + SetTagsResult, + add_missing_tag_for_asset_id, + add_tags_to_reference, + bulk_insert_tags_and_meta, + ensure_tags_exist, + get_reference_tags, + list_tag_counts_for_filtered_assets, + list_tags_with_usage, + remove_missing_tag_for_asset_id, + remove_tags_from_reference, + set_reference_tags, + validate_tags_exist, +) + +__all__ = [ + "AddTagsResult", + "CacheStateRow", + "RemoveTagsResult", + "SetTagsResult", + "UnenrichedReferenceRow", + "add_missing_tag_for_asset_id", + "add_tags_to_reference", + "asset_exists_by_hash", + "bulk_insert_assets", + "bulk_insert_references_ignore_conflicts", + "bulk_insert_tags_and_meta", + "bulk_update_enrichment_level", + "count_active_siblings", + "create_stub_asset", + "bulk_update_is_missing", + "bulk_update_needs_verify", + "convert_metadata_to_rows", + "delete_assets_by_ids", + "delete_orphaned_seed_asset", + "delete_reference_by_id", + "delete_references_by_ids", + "ensure_tags_exist", + "fetch_reference_and_asset", + "fetch_reference_asset_and_tags", + "get_asset_by_hash", + "get_existing_asset_ids", + "get_or_create_reference", + "get_reference_by_file_path", + "get_reference_by_id", + "get_reference_with_owner_check", + "get_reference_ids_by_ids", + "get_reference_tags", + "get_references_by_paths_and_asset_ids", + "get_references_for_prefixes", + "get_unenriched_references", + "get_unreferenced_unhashed_asset_ids", + "insert_reference", + "list_all_file_paths_by_asset_id", + "list_references_by_asset_id", + "list_references_page", + "list_tag_counts_for_filtered_assets", + "list_tags_with_usage", + "mark_references_missing_outside_prefixes", + "reassign_asset_references", + "rebuild_metadata_projection", + "reference_exists", + "reference_exists_for_asset_id", + "remove_missing_tag_for_asset_id", + "remove_tags_from_reference", + "restore_references_by_paths", + "set_reference_metadata", + "set_reference_preview", + "set_reference_system_metadata", + "soft_delete_reference_by_id", + "set_reference_tags", + "update_asset_hash_and_mime", + "update_is_missing_by_asset_id", + "update_reference_access_time", + "update_reference_name", + "update_reference_timestamps", + "update_reference_updated_at", + "upsert_asset", + "upsert_reference", + "validate_tags_exist", +] 
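
For orientation, a minimal sketch (editorial example, not part of the patch) of how the ingest path can compose these re-exported helpers, assuming a SQLAlchemy session from the app's database layer:

    from app.assets.database.queries import get_or_create_reference, upsert_asset

    def ingest_hashed_file(session, hash_str: str, size_bytes: int, file_path: str, name: str):
        # One content row per unique hash; the upsert is idempotent across re-scans.
        asset, created, updated = upsert_asset(
            session, asset_hash=hash_str, size_bytes=size_bytes
        )
        # One reference row per file path; re-running returns the existing row.
        ref, ref_created = get_or_create_reference(
            session, asset_id=asset.id, name=name, file_path=file_path
        )
        session.commit()
        return ref
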
diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py new file mode 100644 index 000000000..cc7168431 --- /dev/null +++ b/app/assets/database/queries/asset.py @@ -0,0 +1,152 @@ +import sqlalchemy as sa +from sqlalchemy import select +from sqlalchemy.dialects import sqlite +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks + + +def asset_exists_by_hash( + session: Session, + asset_hash: str, +) -> bool: + """ + Check if an asset with a given hash exists in database. + """ + row = ( + session.execute( + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) + ) + ).first() + return row is not None + + +def get_asset_by_hash( + session: Session, + asset_hash: str, +) -> Asset | None: + return ( + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) + + +def upsert_asset( + session: Session, + asset_hash: str, + size_bytes: int, + mime_type: str | None = None, +) -> tuple[Asset, bool, bool]: + """Upsert an Asset by hash. Returns (asset, created, updated).""" + vals = {"hash": asset_hash, "size_bytes": int(size_bytes)} + if mime_type: + vals["mime_type"] = mime_type + + ins = ( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) + if not asset: + raise RuntimeError("Asset row not found after upsert.") + + updated = False + if not created: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and not asset.mime_type: + asset.mime_type = mime_type + changed = True + if changed: + updated = True + + return asset, created, updated + + +def create_stub_asset( + session: Session, + size_bytes: int, + mime_type: str | None = None, +) -> Asset: + """Create a new asset with no hash (stub for later enrichment).""" + asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None) + session.add(asset) + session.flush() + return asset + + +def bulk_insert_assets( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash.""" + if not rows: + return + ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash]) + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): + session.execute(ins, chunk) + + +def get_existing_asset_ids( + session: Session, + asset_ids: list[str], +) -> set[str]: + """Return the subset of asset_ids that exist in the database.""" + if not asset_ids: + return set() + found: set[str] = set() + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + rows = session.execute( + select(Asset.id).where(Asset.id.in_(chunk)) + ).fetchall() + found.update(row[0] for row in rows) + return found + + +def update_asset_hash_and_mime( + session: Session, + asset_id: str, + asset_hash: str | None = None, + mime_type: str | None = None, +) -> bool: + """Update asset hash and/or mime_type. 
Returns True if asset was found.""" + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset_hash is not None: + asset.hash = asset_hash + if mime_type is not None and not asset.mime_type: + asset.mime_type = mime_type + return True + + +def reassign_asset_references( + session: Session, + from_asset_id: str, + to_asset_id: str, + reference_id: str, +) -> None: + """Reassign a reference from one asset to another. + + Used when merging a stub asset into an existing asset with the same hash. + """ + ref = session.get(AssetReference, reference_id) + if ref and ref.asset_id == from_asset_id: + ref.asset_id = to_asset_id + + session.flush() diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py new file mode 100644 index 000000000..8b90ae511 --- /dev/null +++ b/app/assets/database/queries/asset_reference.py @@ -0,0 +1,1045 @@ +"""Query functions for the unified AssetReference table. + +This module replaces the separate asset_info.py and cache_state.py query modules, +providing a unified interface for the merged asset_references table. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal +from typing import NamedTuple, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session, noload + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + apply_metadata_filter, + apply_tag_filters, + build_prefix_like_conditions, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now + + +def _check_is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + + +def _scalar_to_row(key: str, ordinal: int, value) -> dict: + """Convert a scalar value to a typed projection row.""" + if isinstance(value, bool): + return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return {"key": key, "ordinal": ordinal, "val_num": num} + if isinstance(value, str): + return {"key": key, "ordinal": ordinal, "val_str": value} + return {"key": key, "ordinal": ordinal, "val_json": value} + + +def convert_metadata_to_rows(key: str, value) -> list[dict]: + """Turn a metadata key/value into typed projection rows.""" + if value is None: + return [] + + if _check_is_scalar(value): + return [_scalar_to_row(key, 0, value)] + + if isinstance(value, list): + if all(_check_is_scalar(x) for x in value): + return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None] + + return [{"key": key, "ordinal": 0, "val_json": value}] + + + + +def get_reference_by_id( + session: Session, + reference_id: str, +) -> AssetReference | None: + return session.get(AssetReference, reference_id) + + +def get_reference_with_owner_check( + session: Session, + reference_id: str, + owner_id: str, +) -> AssetReference: + """Fetch a reference and verify ownership. 
+ + Raises: + ValueError: if reference not found or soft-deleted + PermissionError: if owner_id doesn't match + """ + ref = get_reference_by_id(session, reference_id=reference_id) + if not ref or ref.deleted_at is not None: + raise ValueError(f"AssetReference {reference_id} not found") + if ref.owner_id and ref.owner_id != owner_id: + raise PermissionError("not owner") + return ref + + +def get_reference_by_file_path( + session: Session, + file_path: str, +) -> AssetReference | None: + """Get a reference by its file path.""" + return ( + session.execute( + select(AssetReference).where(AssetReference.file_path == file_path).limit(1) + ) + .scalars() + .first() + ) + + +def count_active_siblings( + session: Session, + asset_id: str, + exclude_reference_id: str, +) -> int: + """Count active (non-deleted) references to an asset, excluding one reference.""" + return ( + session.query(AssetReference) + .filter( + AssetReference.asset_id == asset_id, + AssetReference.id != exclude_reference_id, + AssetReference.deleted_at.is_(None), + ) + .count() + ) + + +def reference_exists_for_asset_id( + session: Session, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def reference_exists( + session: Session, + reference_id: str, +) -> bool: + """Return True if a reference with the given ID exists (not soft-deleted).""" + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.id == reference_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + +def insert_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> AssetReference | None: + """Insert a new AssetReference. Returns None if unique constraint violated.""" + now = get_utc_now() + try: + with session.begin_nested(): + ref = AssetReference( + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + except IntegrityError: + return None + + +def get_or_create_reference( + session: Session, + asset_id: str, + name: str, + owner_id: str = "", + file_path: str | None = None, + mtime_ns: int | None = None, + preview_id: str | None = None, +) -> tuple[AssetReference, bool]: + """Get existing or create new AssetReference. + + For filesystem references (file_path is set), uniqueness is by file_path. + For API references (file_path is None), we look for matching + asset_id + owner_id + name. + + Returns (reference, created). 
+ """ + ref = insert_reference( + session, + asset_id=asset_id, + name=name, + owner_id=owner_id, + file_path=file_path, + mtime_ns=mtime_ns, + preview_id=preview_id, + ) + if ref: + return ref, True + + # Find existing - priority to file_path match, then name match + if file_path: + existing = get_reference_by_file_path(session, file_path) + else: + existing = ( + session.execute( + select(AssetReference) + .where( + AssetReference.asset_id == asset_id, + AssetReference.name == name, + AssetReference.owner_id == owner_id, + AssetReference.file_path.is_(None), + ) + .limit(1) + ) + .unique() + .scalar_one_or_none() + ) + if not existing: + raise RuntimeError("Failed to find AssetReference after insert conflict.") + return existing, False + + +def update_reference_timestamps( + session: Session, + reference: AssetReference, + preview_id: str | None = None, +) -> None: + """Update timestamps and optionally preview_id on existing AssetReference.""" + now = get_utc_now() + if preview_id and reference.preview_id != preview_id: + reference.preview_id = preview_id + reference.updated_at = now + + +def list_references_page( + session: Session, + owner_id: str = "", + limit: int = 100, + offset: int = 0, + name_contains: str | None = None, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + metadata_filter: dict | None = None, + sort: str | None = None, + order: str | None = None, +) -> tuple[list[AssetReference], dict[str, list[str]], int]: + """List references with pagination, filtering, and sorting. + + Returns (references, tag_map, total_count). + """ + base = ( + select(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .options(noload(AssetReference.tags)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) + + sort = (sort or "created_at").lower() + order = (order or "desc").lower() + sort_map = { + "name": AssetReference.name, + "created_at": AssetReference.created_at, + "updated_at": AssetReference.updated_at, + "last_access_time": AssetReference.last_access_time, + "size": Asset.size_bytes, + } + sort_col = sort_map.get(sort, AssetReference.created_at) + sort_exp = sort_col.desc() if order == "desc" else sort_col.asc() + + base = base.order_by(sort_exp).limit(limit).offset(offset) + + count_stmt = ( + select(sa.func.count()) + .select_from(AssetReference) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + count_stmt = count_stmt.where( + AssetReference.name.ilike(f"%{escaped}%", escape=esc) + ) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) + + total = int(session.execute(count_stmt).scalar_one() or 0) + refs = session.execute(base).unique().scalars().all() + + id_list: list[str] = [r.id for r in refs] + tag_map: dict[str, list[str]] = defaultdict(list) + if id_list: + rows = session.execute( + select(AssetReferenceTag.asset_reference_id, 
Tag.name) + .join(Tag, Tag.name == AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id.in_(id_list)) + .order_by(AssetReferenceTag.tag_name.asc()) + ) + for ref_id, tag_name in rows.all(): + tag_map[ref_id].append(tag_name) + + return list(refs), tag_map, total + + +def fetch_reference_asset_and_tags( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset, list[str]] | None: + stmt = ( + select(AssetReference, Asset, Tag.name) + .join(Asset, Asset.id == AssetReference.asset_id) + .join( + AssetReferenceTag, + AssetReferenceTag.asset_reference_id == AssetReference.id, + isouter=True, + ) + .join(Tag, Tag.name == AssetReferenceTag.tag_name, isouter=True) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .options(noload(AssetReference.tags)) + .order_by(Tag.name.asc()) + ) + + rows = session.execute(stmt).all() + if not rows: + return None + + first_ref, first_asset, _ = rows[0] + tags: list[str] = [] + seen: set[str] = set() + for _ref, _asset, tag_name in rows: + if tag_name and tag_name not in seen: + seen.add(tag_name) + tags.append(tag_name) + return first_ref, first_asset, tags + + +def fetch_reference_and_asset( + session: Session, + reference_id: str, + owner_id: str = "", +) -> tuple[AssetReference, Asset] | None: + stmt = ( + select(AssetReference, Asset) + .join(Asset, Asset.id == AssetReference.asset_id) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetReference.tags)) + ) + pair = session.execute(stmt).first() + if not pair: + return None + return pair[0], pair[1] + + +def update_reference_access_time( + session: Session, + reference_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or get_utc_now() + stmt = sa.update(AssetReference).where(AssetReference.id == reference_id) + if only_if_newer: + stmt = stmt.where( + sa.or_( + AssetReference.last_access_time.is_(None), + AssetReference.last_access_time < ts, + ) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def update_reference_name( + session: Session, + reference_id: str, + name: str, +) -> None: + """Update the name of an AssetReference.""" + now = get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(name=name, updated_at=now) + ) + + +def update_reference_updated_at( + session: Session, + reference_id: str, + ts: datetime | None = None, +) -> None: + """Update the updated_at timestamp of an AssetReference.""" + ts = ts or get_utc_now() + session.execute( + sa.update(AssetReference) + .where(AssetReference.id == reference_id) + .values(updated_at=ts) + ) + + +def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None: + """Delete and rebuild AssetReferenceMeta rows from merged system+user metadata. + + The merged dict is ``{**system_metadata, **user_metadata}`` so user keys + override system keys of the same name. 
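+
+    For example (values illustrative), ``system_metadata={"format": "safetensors",
+    "epoch": 3}`` merged with ``user_metadata={"epoch": 7}`` projects rows for
+    ``{"format": "safetensors", "epoch": 7}``.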
+ """ + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == ref.id + ) + ) + session.flush() + + merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})} + if not merged: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in merged.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=ref.id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def set_reference_metadata( + session: Session, + reference_id: str, + user_metadata: dict | None = None, +) -> None: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.user_metadata = user_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + rebuild_metadata_projection(session, ref) + + +def set_reference_system_metadata( + session: Session, + reference_id: str, + system_metadata: dict | None = None, +) -> None: + """Set system_metadata on a reference and rebuild the merged projection.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.system_metadata = system_metadata or {} + ref.updated_at = get_utc_now() + session.flush() + + rebuild_metadata_projection(session, ref) + + +def delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetReference).where( + AssetReference.id == reference_id, + build_visible_owner_clause(owner_id), + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def soft_delete_reference_by_id( + session: Session, + reference_id: str, + owner_id: str, +) -> bool: + """Mark a reference as soft-deleted by setting deleted_at timestamp. + + Returns True if the reference was found and marked deleted. + """ + now = get_utc_now() + stmt = ( + sa.update(AssetReference) + .where( + AssetReference.id == reference_id, + AssetReference.deleted_at.is_(None), + build_visible_owner_clause(owner_id), + ) + .values(deleted_at=now) + ) + return int(session.execute(stmt).rowcount or 0) > 0 + + +def set_reference_preview( + session: Session, + reference_id: str, + preview_reference_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. 
Raises on unknown IDs.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + if preview_reference_id is None: + ref.preview_id = None + else: + if not session.get(AssetReference, preview_reference_id): + raise ValueError(f"Preview AssetReference {preview_reference_id} not found") + ref.preview_id = preview_reference_id + + ref.updated_at = get_utc_now() + session.flush() + + +class CacheStateRow(NamedTuple): + """Row from reference query with cache state data.""" + + reference_id: str + file_path: str + mtime_ns: int | None + needs_verify: bool + asset_id: str + asset_hash: str | None + size_bytes: int | None + + +def list_references_by_asset_id( + session: Session, + asset_id: str, +) -> Sequence[AssetReference]: + return ( + session.execute( + select(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .order_by(AssetReference.id.asc()) + ) + .scalars() + .all() + ) + + +def list_all_file_paths_by_asset_id( + session: Session, + asset_id: str, +) -> list[str]: + """Return every file_path for an asset, including soft-deleted/missing refs. + + Used for orphan cleanup where all on-disk files must be removed. + """ + return list( + session.execute( + select(AssetReference.file_path) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.file_path.isnot(None)) + ) + .scalars() + .all() + ) + + +def upsert_reference( + session: Session, + asset_id: str, + file_path: str, + name: str, + mtime_ns: int, + owner_id: str = "", +) -> tuple[bool, bool]: + """Upsert a reference by file_path. Returns (created, updated). + + Also restores references that were previously marked as missing. + """ + now = get_utc_now() + vals = { + "asset_id": asset_id, + "file_path": file_path, + "name": name, + "owner_id": owner_id, + "mtime_ns": int(mtime_ns), + "is_missing": False, + "created_at": now, + "updated_at": now, + "last_access_time": now, + } + ins = ( + sqlite.insert(AssetReference) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetReference.file_path]) + ) + res = session.execute(ins) + created = int(res.rowcount or 0) > 0 + + if created: + return True, False + + upd = ( + sa.update(AssetReference) + .where(AssetReference.file_path == file_path) + .where( + sa.or_( + AssetReference.asset_id != asset_id, + AssetReference.mtime_ns.is_(None), + AssetReference.mtime_ns != int(mtime_ns), + AssetReference.is_missing == True, # noqa: E712 + AssetReference.deleted_at.isnot(None), + ) + ) + .values( + asset_id=asset_id, mtime_ns=int(mtime_ns), is_missing=False, + deleted_at=None, updated_at=now, + ) + ) + res2 = session.execute(upd) + updated = int(res2.rowcount or 0) > 0 + return False, updated + + +def mark_references_missing_outside_prefixes( + session: Session, + valid_prefixes: list[str], +) -> int: + """Mark references as missing when file_path doesn't match any valid prefix. + + Returns number of references marked as missing. 
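+
+    Illustrative example (paths hypothetical): with
+    ``valid_prefixes=["/data/models"]``, a reference at
+    ``/data/models/ckpt/a.safetensors`` is left untouched, while one at
+    ``/tmp/stale/a.safetensors`` is marked missing.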
+ """ + if not valid_prefixes: + return 0 + + conds = build_prefix_like_conditions(valid_prefixes) + matches_valid_prefix = sa.or_(*conds) + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(~matches_valid_prefix) + .where(AssetReference.is_missing == False) # noqa: E712 + .values(is_missing=True) + ) + return result.rowcount + + +def restore_references_by_paths(session: Session, file_paths: list[str]) -> int: + """Restore references that were previously marked as missing. + + Returns number of references restored. + """ + if not file_paths: + return 0 + + total = 0 + for chunk in iter_chunks(file_paths, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.file_path.in_(chunk)) + .where(AssetReference.is_missing == True) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=False) + ) + total += result.rowcount + return total + + +def get_unreferenced_unhashed_asset_ids(session: Session) -> list[str]: + """Get IDs of unhashed assets (hash=None) with no active references. + + An asset is considered unreferenced if it has no references, + or all its references are marked as missing. + + Returns list of asset IDs that are unreferenced. + """ + active_ref_exists = ( + sa.select(sa.literal(1)) + .where(AssetReference.asset_id == Asset.id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + .correlate(Asset) + .exists() + ) + unreferenced_subq = sa.select(Asset.id).where( + Asset.hash.is_(None), ~active_ref_exists + ) + return [row[0] for row in session.execute(unreferenced_subq).all()] + + +def delete_assets_by_ids(session: Session, asset_ids: list[str]) -> int: + """Delete assets and their references by ID. + + Returns number of assets deleted. + """ + if not asset_ids: + return 0 + total = 0 + for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS): + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id.in_(chunk)) + ) + result = session.execute(sa.delete(Asset).where(Asset.id.in_(chunk))) + total += result.rowcount + return total + + +def get_references_for_prefixes( + session: Session, + prefixes: list[str], + *, + include_missing: bool = False, +) -> list[CacheStateRow]: + """Get all references with file paths matching any of the given prefixes. 
+ + Args: + session: Database session + prefixes: List of absolute directory prefixes to match + include_missing: If False (default), exclude references marked as missing + + Returns: + List of cache state rows with joined asset data + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.file_path, + AssetReference.mtime_ns, + AssetReference.needs_verify, + AssetReference.asset_id, + Asset.hash, + Asset.size_bytes, + ) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + ) + + if not include_missing: + query = query.where(AssetReference.is_missing == False) # noqa: E712 + + rows = session.execute( + query.order_by(AssetReference.asset_id.asc(), AssetReference.id.asc()) + ).all() + + return [ + CacheStateRow( + reference_id=row[0], + file_path=row[1], + mtime_ns=row[2], + needs_verify=row[3], + asset_id=row[4], + asset_hash=row[5], + size_bytes=int(row[6]) if row[6] is not None else None, + ) + for row in rows + ] + + +def bulk_update_needs_verify( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set needs_verify flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(needs_verify=value) + ) + total += result.rowcount + return total + + +def bulk_update_is_missing( + session: Session, reference_ids: list[str], value: bool +) -> int: + """Set is_missing flag for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(chunk)) + .values(is_missing=value) + ) + total += result.rowcount + return total + + +def update_is_missing_by_asset_id( + session: Session, asset_id: str, value: bool +) -> int: + """Set is_missing flag for ALL references belonging to an asset. + + Returns: Number of rows updated + """ + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=value) + ) + return result.rowcount + + +def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: + """Delete references by their IDs. + + Returns: Number of rows deleted + """ + if not reference_ids: + return 0 + total = 0 + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + sa.delete(AssetReference).where(AssetReference.id.in_(chunk)) + ) + total += result.rowcount + return total + + +def delete_orphaned_seed_asset(session: Session, asset_id: str) -> bool: + """Delete a seed asset (hash is None) and its references. 
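+
+    A seed asset is one whose content hash has not been computed yet (e.g. a
+    stub created during a fast scan); assets with a hash are never deleted here.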
+ + Returns: True if asset was deleted, False if not found or has a hash + """ + asset = session.get(Asset, asset_id) + if not asset: + return False + if asset.hash is not None: + return False + session.execute( + sa.delete(AssetReference).where(AssetReference.asset_id == asset_id) + ) + session.delete(asset) + return True + + +class UnenrichedReferenceRow(NamedTuple): + """Row for references needing enrichment.""" + + reference_id: str + asset_id: str + file_path: str + enrichment_level: int + + +def get_unenriched_references( + session: Session, + prefixes: list[str], + max_level: int = 0, + limit: int = 1000, +) -> list[UnenrichedReferenceRow]: + """Get references that need enrichment (enrichment_level <= max_level). + + Args: + session: Database session + prefixes: List of absolute directory prefixes to scan + max_level: Maximum enrichment level to include (0=stubs, 1=metadata done) + limit: Maximum number of rows to return + + Returns: + List of unenriched reference rows with file paths + """ + if not prefixes: + return [] + + conds = build_prefix_like_conditions(prefixes) + + query = ( + sa.select( + AssetReference.id, + AssetReference.asset_id, + AssetReference.file_path, + AssetReference.enrichment_level, + ) + .where(AssetReference.file_path.isnot(None)) + .where(AssetReference.deleted_at.is_(None)) + .where(sa.or_(*conds)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.enrichment_level <= max_level) + .order_by(AssetReference.id.asc()) + .limit(limit) + ) + + rows = session.execute(query).all() + return [ + UnenrichedReferenceRow( + reference_id=row[0], + asset_id=row[1], + file_path=row[2], + enrichment_level=row[3], + ) + for row in rows + ] + + +def bulk_update_enrichment_level( + session: Session, + reference_ids: list[str], + level: int, +) -> int: + """Update enrichment level for multiple references. + + Returns: Number of rows updated + """ + if not reference_ids: + return 0 + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.id.in_(reference_ids)) + .values(enrichment_level=level) + ) + return result.rowcount + + +def bulk_insert_references_ignore_conflicts( + session: Session, + rows: list[dict], +) -> None: + """Bulk insert reference rows with ON CONFLICT DO NOTHING on file_path. + + Each dict should have: id, asset_id, file_path, name, owner_id, mtime_ns, etc. + The is_missing field is automatically set to False for new inserts. + """ + if not rows: + return + enriched_rows = [{**row, "is_missing": False} for row in rows] + ins = sqlite.insert(AssetReference).on_conflict_do_nothing( + index_elements=[AssetReference.file_path] + ) + for chunk in iter_chunks(enriched_rows, calculate_rows_per_statement(14)): + session.execute(ins, chunk) + + +def get_references_by_paths_and_asset_ids( + session: Session, + path_to_asset: dict[str, str], +) -> set[str]: + """Query references to find paths where our asset_id won the insert. 
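+
+    Intended for the bulk-insert flow: after an ON CONFLICT DO NOTHING insert,
+    a path "won" only if the stored row carries the asset_id we attempted to
+    insert. Sketch (values illustrative):
+
+        path_to_asset = {"/data/models/a.bin": "asset-1"}
+        winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
+        lost_paths = set(path_to_asset) - winners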
+ + Args: + path_to_asset: Mapping of file_path -> asset_id we tried to insert + + Returns: + Set of file_paths where our asset_id is present + """ + if not path_to_asset: + return set() + + pairs = list(path_to_asset.items()) + winners: set[str] = set() + + # Each pair uses 2 bind params, so chunk at MAX_BIND_PARAMS // 2 + for chunk in iter_chunks(pairs, MAX_BIND_PARAMS // 2): + pairwise = sa.tuple_(AssetReference.file_path, AssetReference.asset_id).in_( + chunk + ) + result = session.execute( + select(AssetReference.file_path).where(pairwise) + ) + winners.update(result.scalars().all()) + + return winners + + +def get_reference_ids_by_ids( + session: Session, + reference_ids: list[str], +) -> set[str]: + """Query to find which reference IDs exist in the database.""" + if not reference_ids: + return set() + + found: set[str] = set() + for chunk in iter_chunks(reference_ids, MAX_BIND_PARAMS): + result = session.execute( + select(AssetReference.id).where(AssetReference.id.in_(chunk)) + ) + found.update(result.scalars().all()) + return found diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..89bb49327 --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,127 @@ +"""Shared utilities for database query modules.""" + +import os +from decimal import Decimal +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import exists + +from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag +from app.assets.helpers import escape_sql_like_string, normalize_tags + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return + yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row)) + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. + + Owner-less rows are visible to everyone. 
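+
+    Concretely: ``owner_id=""`` matches only owner-less rows, while
+    ``owner_id="alice"`` (name illustrative) matches rows owned by ``""`` or
+    ``"alice"``.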
+ """ + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetReference.owner_id == "" + return AssetReference.owner_id.in_(["", owner_id]) + + +def build_prefix_like_conditions( + prefixes: list[str], +) -> list[sa.sql.ColumnElement]: + """Build LIKE conditions for matching file paths under directory prefixes.""" + conds = [] + for p in prefixes: + base = os.path.abspath(p) + if not base.endswith(os.sep): + base += os.sep + escaped, esc = escape_sql_like_string(base) + conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) + return conds + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + return sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py new file mode 100644 index 000000000..f4126dba8 --- /dev/null +++ b/app/assets/database/queries/tags.py @@ -0,0 +1,418 @@ +from dataclasses import dataclass +from typing import Iterable, Sequence + +import sqlalchemy as sa +from sqlalchemy import delete, func, select +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session + +from app.assets.database.models import ( + Asset, + AssetReference, + AssetReferenceMeta, + AssetReferenceTag, + Tag, +) +from app.assets.database.queries.common import ( + apply_metadata_filter, + apply_tag_filters, + build_visible_owner_clause, + iter_row_chunks, +) +from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + + +@dataclass(frozen=True) +class 
AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +def validate_tags_exist(session: Session, tags: list[str]) -> None: + """Raise ValueError if any of the given tag names do not exist.""" + existing_tag_names = set( + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + ) + missing = [t for t in tags if t not in existing_tag_names] + if missing: + raise ValueError(f"Unknown tags: {missing}") + + +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_reference_tags(session: Session, reference_id: str) -> list[str]: + return [ + tag_name + for (tag_name,) in ( + session.execute( + select(AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id == reference_id) + .order_by(AssetReferenceTag.tag_name.asc()) + ) + ).all() + ] + + +def set_reference_tags( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> SetTagsResult: + desired = normalize_tags(tags) + + current = set(get_reference_tags(session, reference_id)) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired)) + + +def add_tags_to_reference( + session: Session, + reference_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + reference_row: AssetReference | None = None, +) -> AddTagsResult: + if not reference_row: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return AddTagsResult(added=[], already_present=[], total_tags=total) + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = set(get_reference_tags(session, reference_id)) + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetReferenceTag( + asset_reference_id=reference_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_reference_tags(session, reference_id=reference_id)) + return AddTagsResult( + added=sorted(((after - current) & want)), + already_present=sorted(want & current), + total_tags=sorted(after), + ) + + +def 
remove_tags_from_reference( + session: Session, + reference_id: str, + tags: Sequence[str], +) -> RemoveTagsResult: + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=[], not_present=[], total_tags=total) + + existing = set(get_reference_tags(session, reference_id)) + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id == reference_id, + AssetReferenceTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_reference_tags(session, reference_id=reference_id) + return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total) + + +def add_missing_tag_for_asset_id( + session: Session, + asset_id: str, + origin: str = "automatic", +) -> None: + select_rows = ( + sa.select( + AssetReference.id.label("asset_reference_id"), + sa.literal("missing").label("tag_name"), + sa.literal(origin).label("origin"), + sa.literal(get_utc_now()).label("added_at"), + ) + .where(AssetReference.asset_id == asset_id) + .where( + sa.not_( + sa.exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == "missing") + ) + ) + ) + ) + session.execute( + sqlite.insert(AssetReferenceTag) + .from_select( + ["asset_reference_id", "tag_name", "origin", "added_at"], + select_rows, + ) + .on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + ) + + +def remove_missing_tag_for_asset_id( + session: Session, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetReferenceTag).where( + AssetReferenceTag.asset_reference_id.in_( + sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id) + ), + AssetReferenceTag.tag_name == "missing", + ) + ) + + +def list_tags_with_usage( + session: Session, + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + include_zero: bool = True, + order: str = "count_desc", + owner_id: str = "", +) -> tuple[list[tuple[str, str, int]], int]: + counts_sq = ( + select( + AssetReferenceTag.tag_name.label("tag_name"), + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .select_from(AssetReferenceTag) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where( + sa.or_( + AssetReference.is_missing == False, # noqa: E712 + AssetReferenceTag.tag_name == "missing", + ) + ) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + .subquery() + ) + + q = ( + select( + Tag.name, + Tag.tag_type, + func.coalesce(counts_sq.c.cnt, 0).label("count"), + ) + .select_from(Tag) + .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True) + ) + + if prefix: + escaped, esc = escape_sql_like_string(prefix.strip().lower()) + q = q.where(Tag.name.like(escaped + "%", escape=esc)) + + if not include_zero: + q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) + + if order == "name_asc": + q = q.order_by(Tag.name.asc()) + else: + q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc()) + + total_q = select(func.count()).select_from(Tag) + if prefix: + escaped, esc = 
escape_sql_like_string(prefix.strip().lower()) + total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc)) + if not include_zero: + visible_tags_sq = ( + select(AssetReferenceTag.tag_name) + .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) + .where(build_visible_owner_clause(owner_id)) + .where( + sa.or_( + AssetReference.is_missing == False, # noqa: E712 + AssetReferenceTag.tag_name == "missing", + ) + ) + .where(AssetReference.deleted_at.is_(None)) + .group_by(AssetReferenceTag.tag_name) + ) + total_q = total_q.where(Tag.name.in_(visible_tags_sq)) + + rows = (session.execute(q.limit(limit).offset(offset))).all() + total = (session.execute(total_q)).scalar_one() + + rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] + return rows_norm, int(total or 0) + + +def list_tag_counts_for_filtered_assets( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + """Return tag counts for assets matching the given filters. + + Uses the same filtering logic as list_references_page but returns + {tag_name: count} instead of paginated references. + """ + # Build a subquery of matching reference IDs + ref_sq = ( + select(AssetReference.id) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags) + ref_sq = apply_metadata_filter(ref_sq, metadata_filter) + ref_sq = ref_sq.subquery() + + # Count tags across those references + q = ( + select( + AssetReferenceTag.tag_name, + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id))) + .group_by(AssetReferenceTag.tag_name) + .order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc()) + .limit(limit) + ) + + rows = session.execute(q).all() + return {tag_name: int(cnt) for tag_name, cnt in rows} + + +def bulk_insert_tags_and_meta( + session: Session, + tag_rows: list[dict], + meta_rows: list[dict], +) -> None: + """Batch insert into asset_reference_tags and asset_reference_meta. + + Uses ON CONFLICT DO NOTHING. 
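+
+    Example row shapes (field values illustrative):
+
+        tag_rows = [{"asset_reference_id": "ref-1", "tag_name": "checkpoints",
+                     "origin": "automatic", "added_at": get_utc_now()}]
+        meta_rows = [{"asset_reference_id": "ref-1", "key": "epoch",
+                      "ordinal": 0, "val_num": 3}]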
+ + Args: + session: Database session + tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at + meta_rows: Dicts with: asset_reference_id, key, ordinal, val_* + """ + if tag_rows: + ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing( + index_elements=[ + AssetReferenceTag.asset_reference_id, + AssetReferenceTag.tag_name, + ] + ) + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): + session.execute(ins_tags, chunk) + + if meta_rows: + ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing( + index_elements=[ + AssetReferenceMeta.asset_reference_id, + AssetReferenceMeta.key, + AssetReferenceMeta.ordinal, + ] + ) + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): + session.execute(ins_meta, chunk) diff --git a/app/assets/helpers.py b/app/assets/helpers.py new file mode 100644 index 000000000..3798f3933 --- /dev/null +++ b/app/assets/helpers.py @@ -0,0 +1,65 @@ +import os +from datetime import datetime, timezone +from typing import Sequence + + +def select_best_live_path(states: Sequence) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path + + +def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: + """Escapes %, _ and the escape char in a LIKE prefix. + + Returns (escaped_prefix, escape_char). + """ + s = s.replace(escape, escape + escape) # escape the escape char first + s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards + return s, escape + + +def get_utc_now() -> datetime: + """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" + return datetime.now(timezone.utc).replace(tzinfo=None) + + +def normalize_tags(tags: list[str] | None) -> list[str]: + """ + Normalize a list of tags by: + - Stripping whitespace and converting to lowercase. + - Removing duplicates. + """ + return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip())) + + +def validate_blake3_hash(s: str) -> str: + """Validate and normalize a blake3 hash string. + + Returns canonical 'blake3:' or raises ValueError. 
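+
+    Canonical form is ``blake3:`` followed by exactly 64 lowercase hex
+    characters. Sketch (digest illustrative):
+
+        validate_blake3_hash("BLAKE3:" + "ab" * 32)  # -> "blake3:" + "ab" * 32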
+ """ + s = s.strip().lower() + if not s or ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if ( + algo != "blake3" + or len(digest) != 64 + or any(c for c in digest if c not in "0123456789abcdef") + ): + raise ValueError("hash must be 'blake3:'") + return f"{algo}:{digest}" diff --git a/app/assets/scanner.py b/app/assets/scanner.py new file mode 100644 index 000000000..ebb6869af --- /dev/null +++ b/app/assets/scanner.py @@ -0,0 +1,582 @@ +import logging +import os +from pathlib import Path +from typing import Callable, Literal, TypedDict + +import folder_paths +from app.assets.database.queries import ( + add_missing_tag_for_asset_id, + bulk_update_enrichment_level, + bulk_update_is_missing, + bulk_update_needs_verify, + delete_orphaned_seed_asset, + delete_references_by_ids, + ensure_tags_exist, + get_asset_by_hash, + get_reference_by_id, + get_references_for_prefixes, + get_unenriched_references, + mark_references_missing_outside_prefixes, + reassign_asset_references, + remove_missing_tag_for_asset_id, + set_reference_system_metadata, + update_asset_hash_and_mime, +) +from app.assets.services.bulk_ingest import ( + SeedAssetSpec, + batch_insert_seed_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + is_visible, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash +from app.assets.services.metadata_extract import extract_file_metadata +from app.assets.services.path_utils import ( + compute_relative_filename, + get_comfy_models_folders, + get_name_and_tags_from_asset_path, +) +from app.database.db import create_session + + +class _RefInfo(TypedDict): + ref_id: str + file_path: str + exists: bool + stat_unchanged: bool + needs_verify: bool + + +class _AssetAccumulator(TypedDict): + hash: str | None + size_db: int + refs: list[_RefInfo] + + +RootType = Literal["models", "input", "output"] + + +def get_prefixes_for_root(root: RootType) -> list[str]: + if root == "models": + bases: list[str] = [] + for _bucket, paths in get_comfy_models_folders(): + bases.extend(paths) + return [os.path.abspath(p) for p in bases] + if root == "input": + return [os.path.abspath(folder_paths.get_input_directory())] + if root == "output": + return [os.path.abspath(folder_paths.get_output_directory())] + return [] + + +def get_all_known_prefixes() -> list[str]: + """Get all known asset prefixes across all root types.""" + all_roots: tuple[RootType, ...] = ("models", "input", "output") + return [p for root in all_roots for p in get_prefixes_for_root(root)] + + +def collect_models_files() -> list[str]: + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): + rel_files = folder_paths.get_filename_list(folder_name) or [] + for rel_path in rel_files: + if not all(is_visible(part) for part in Path(rel_path).parts): + continue + abs_path = folder_paths.get_full_path(folder_name, rel_path) + if not abs_path: + continue + abs_path = os.path.abspath(abs_path) + allowed = False + abs_p = Path(abs_path) + for b in bases: + if abs_p.is_relative_to(os.path.abspath(b)): + allowed = True + break + if allowed: + out.append(abs_path) + return out + + +def sync_references_with_filesystem( + session, + root: RootType, + collect_existing_paths: bool = False, + update_missing_tags: bool = False, +) -> set[str] | None: + """Reconcile asset references with filesystem for a root. 
+ + - Toggle needs_verify per reference using mtime/size stat check + - For hashed assets with at least one stat-unchanged ref: delete stale missing refs + - For seed assets with all refs missing: delete Asset and its references + - Optionally add/remove 'missing' tags based on stat check in this root + - Optionally return surviving absolute paths + + Args: + session: Database session + root: Root type to scan + collect_existing_paths: If True, return set of surviving file paths + update_missing_tags: If True, update 'missing' tags based on file status + + Returns: + Set of surviving absolute paths if collect_existing_paths=True, else None + """ + prefixes = get_prefixes_for_root(root) + if not prefixes: + return set() if collect_existing_paths else None + + rows = get_references_for_prefixes( + session, prefixes, include_missing=update_missing_tags + ) + + by_asset: dict[str, _AssetAccumulator] = {} + for row in rows: + acc = by_asset.get(row.asset_id) + if acc is None: + acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []} + by_asset[row.asset_id] = acc + + stat_unchanged = False + try: + exists = True + stat_unchanged = verify_file_unchanged( + mtime_db=row.mtime_ns, + size_db=acc["size_db"], + stat_result=os.stat(row.file_path, follow_symlinks=True), + ) + except FileNotFoundError: + exists = False + except PermissionError: + exists = True + logging.debug("Permission denied accessing %s", row.file_path) + except OSError as e: + exists = False + logging.debug("OSError checking %s: %s", row.file_path, e) + + acc["refs"].append( + { + "ref_id": row.reference_id, + "file_path": row.file_path, + "exists": exists, + "stat_unchanged": stat_unchanged, + "needs_verify": row.needs_verify, + } + ) + + to_set_verify: list[str] = [] + to_clear_verify: list[str] = [] + stale_ref_ids: list[str] = [] + to_mark_missing: list[str] = [] + to_clear_missing: list[str] = [] + survivors: set[str] = set() + + for aid, acc in by_asset.items(): + a_hash = acc["hash"] + refs = acc["refs"] + any_unchanged = any(r["stat_unchanged"] for r in refs) + all_missing = all(not r["exists"] for r in refs) + + for r in refs: + if not r["exists"]: + to_mark_missing.append(r["ref_id"]) + continue + if r["stat_unchanged"]: + to_clear_missing.append(r["ref_id"]) + if r["needs_verify"]: + to_clear_verify.append(r["ref_id"]) + if not r["stat_unchanged"] and not r["needs_verify"]: + to_set_verify.append(r["ref_id"]) + + if a_hash is None: + if refs and all_missing: + delete_orphaned_seed_asset(session, aid) + else: + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + continue + + if any_unchanged: + for r in refs: + if not r["exists"]: + stale_ref_ids.append(r["ref_id"]) + if update_missing_tags: + try: + remove_missing_tag_for_asset_id(session, asset_id=aid) + except Exception as e: + logging.warning( + "Failed to remove missing tag for asset %s: %s", aid, e + ) + elif update_missing_tags: + try: + add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic") + except Exception as e: + logging.warning("Failed to add missing tag for asset %s: %s", aid, e) + + for r in refs: + if r["exists"]: + survivors.add(os.path.abspath(r["file_path"])) + + delete_references_by_ids(session, stale_ref_ids) + stale_set = set(stale_ref_ids) + to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set] + bulk_update_is_missing(session, to_mark_missing, value=True) + bulk_update_is_missing(session, to_clear_missing, value=False) + bulk_update_needs_verify(session, 
to_set_verify, value=True) + bulk_update_needs_verify(session, to_clear_verify, value=False) + + return survivors if collect_existing_paths else None + + +def sync_root_safely(root: RootType) -> set[str]: + """Sync a single root's references with the filesystem. + + Returns survivors (existing paths) or empty set on failure. + """ + try: + with create_session() as sess: + survivors = sync_references_with_filesystem( + sess, + root, + collect_existing_paths=True, + update_missing_tags=True, + ) + sess.commit() + return survivors or set() + except Exception as e: + logging.exception("fast DB scan failed for %s: %s", root, e) + return set() + + +def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int: + """Mark references as missing when outside the given prefixes. + + This is a non-destructive soft-delete. Returns count marked or 0 on failure. + """ + try: + with create_session() as sess: + count = mark_references_missing_outside_prefixes(sess, prefixes) + sess.commit() + return count + except Exception as e: + logging.exception("marking missing assets failed: %s", e) + return 0 + + +def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]: + """Collect all file paths for the given roots.""" + paths: list[str] = [] + if "models" in roots: + paths.extend(collect_models_files()) + if "input" in roots: + paths.extend(list_files_recursively(folder_paths.get_input_directory())) + if "output" in roots: + paths.extend(list_files_recursively(folder_paths.get_output_directory())) + return paths + + +def build_asset_specs( + paths: list[str], + existing_paths: set[str], + enable_metadata_extraction: bool = True, + compute_hashes: bool = False, +) -> tuple[list[SeedAssetSpec], set[str], int]: + """Build asset specs from paths, returning (specs, tag_pool, skipped_count). 
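+
+    Typical fast-scan usage (sketch):
+
+        specs, tag_pool, skipped = build_asset_specs(paths, existing_paths)
+        created = insert_asset_specs(specs, tag_pool)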
+ + Args: + paths: List of file paths to process + existing_paths: Set of paths that already exist in the database + enable_metadata_extraction: If True, extract tier 1 & 2 metadata + compute_hashes: If True, compute blake3 hashes (slow for large files) + """ + specs: list[SeedAssetSpec] = [] + tag_pool: set[str] = set() + skipped = 0 + + for p in paths: + abs_p = os.path.abspath(p) + if abs_p in existing_paths: + skipped += 1 + continue + try: + stat_p = os.stat(abs_p, follow_symlinks=True) + except OSError: + continue + if not stat_p.st_size: + continue + name, tags = get_name_and_tags_from_asset_path(abs_p) + rel_fname = compute_relative_filename(abs_p) + + # Extract metadata (tier 1: filesystem, tier 2: safetensors header) + metadata = None + if enable_metadata_extraction: + metadata = extract_file_metadata( + abs_p, + stat_result=stat_p, + relative_filename=rel_fname, + ) + + # Compute hash if requested + asset_hash: str | None = None + if compute_hashes: + try: + digest, _ = compute_blake3_hash(abs_p) + asset_hash = "blake3:" + digest + except Exception as e: + logging.warning("Failed to hash %s: %s", abs_p, e) + + mime_type = metadata.content_type if metadata else None + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": get_mtime_ns(stat_p), + "info_name": name, + "tags": tags, + "fname": rel_fname, + "metadata": metadata, + "hash": asset_hash, + "mime_type": mime_type, + "job_id": None, + } + ) + tag_pool.update(tags) + + return specs, tag_pool, skipped + + + +def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int: + """Insert asset specs into database, returning count of created refs.""" + if not specs: + return 0 + with create_session() as sess: + if tag_pool: + ensure_tags_exist(sess, tag_pool, tag_type="user") + result = batch_insert_seed_assets(sess, specs=specs, owner_id="") + sess.commit() + return result.inserted_refs + + +# Enrichment level constants +ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only +ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type) +ENRICHMENT_HASHED = 2 # Hash computed (blake3) + + +def get_unenriched_assets_for_roots( + roots: tuple[RootType, ...], + max_level: int = ENRICHMENT_STUB, + limit: int = 1000, +) -> list: + """Get assets that need enrichment for the given roots. + + Args: + roots: Tuple of root types to scan + max_level: Maximum enrichment level to include + limit: Maximum number of rows to return + + Returns: + List of UnenrichedReferenceRow + """ + prefixes: list[str] = [] + for root in roots: + prefixes.extend(get_prefixes_for_root(root)) + + if not prefixes: + return [] + + with create_session() as sess: + return get_unenriched_references( + sess, prefixes, max_level=max_level, limit=limit + ) + + +def enrich_asset( + session, + file_path: str, + reference_id: str, + asset_id: str, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> int: + """Enrich a single asset with metadata and/or hash. 
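+
+    Per-row sketch (rows typically come from get_unenriched_assets_for_roots;
+    see enrich_assets_batch below for the batch driver):
+
+        level = enrich_asset(
+            session, file_path=row.file_path, reference_id=row.reference_id,
+            asset_id=row.asset_id, extract_metadata=True, compute_hash=False,
+        )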
+ + Args: + session: Database session (caller manages lifecycle) + file_path: Absolute path to the file + reference_id: ID of the reference to update + asset_id: ID of the asset to update (for mime_type and hash) + extract_metadata: If True, extract safetensors header and mime type + compute_hash: If True, compute blake3 hash + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + New enrichment level achieved + """ + new_level = ENRICHMENT_STUB + + try: + stat_p = os.stat(file_path, follow_symlinks=True) + except OSError: + return new_level + + initial_mtime_ns = get_mtime_ns(stat_p) + rel_fname = compute_relative_filename(file_path) + mime_type: str | None = None + metadata = None + + if extract_metadata: + metadata = extract_file_metadata( + file_path, + stat_result=stat_p, + relative_filename=rel_fname, + ) + if metadata: + mime_type = metadata.content_type + new_level = ENRICHMENT_METADATA + + full_hash: str | None = None + if compute_hash: + try: + mtime_before = get_mtime_ns(stat_p) + size_before = stat_p.st_size + + # Restore checkpoint if available and file unchanged + checkpoint = None + if hash_checkpoints is not None: + checkpoint = hash_checkpoints.get(file_path) + if checkpoint is not None: + cur_stat = os.stat(file_path, follow_symlinks=True) + if (checkpoint.mtime_ns != get_mtime_ns(cur_stat) + or checkpoint.file_size != cur_stat.st_size): + checkpoint = None + hash_checkpoints.pop(file_path, None) + else: + mtime_before = get_mtime_ns(cur_stat) + + digest, new_checkpoint = compute_blake3_hash( + file_path, + interrupt_check=interrupt_check, + checkpoint=checkpoint, + ) + + if digest is None: + # Interrupted — save checkpoint for later resumption + if hash_checkpoints is not None and new_checkpoint is not None: + new_checkpoint.mtime_ns = mtime_before + new_checkpoint.file_size = size_before + hash_checkpoints[file_path] = new_checkpoint + return new_level + + # Completed — clear any saved checkpoint + if hash_checkpoints is not None: + hash_checkpoints.pop(file_path, None) + + stat_after = os.stat(file_path, follow_symlinks=True) + mtime_after = get_mtime_ns(stat_after) + if mtime_before != mtime_after: + logging.warning("File modified during hashing, discarding hash: %s", file_path) + else: + full_hash = f"blake3:{digest}" + metadata_ok = not extract_metadata or metadata is not None + if metadata_ok: + new_level = ENRICHMENT_HASHED + except Exception as e: + logging.warning("Failed to hash %s: %s", file_path, e) + + # Optimistic guard: if the reference's mtime_ns changed since we + # started (e.g. ingest_existing_file updated it), our results are + # stale — discard them to avoid overwriting fresh registration data. 
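+    # Only reads and in-memory work have happened so far in this call, so the
+    # rollback below just discards the stale read snapshot.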
+ ref = get_reference_by_id(session, reference_id) + if ref is None or ref.mtime_ns != initial_mtime_ns: + session.rollback() + logging.info( + "Ref %s mtime changed during enrichment, discarding stale result", + reference_id, + ) + return ENRICHMENT_STUB + + if extract_metadata and metadata: + system_metadata = metadata.to_user_metadata() + set_reference_system_metadata(session, reference_id, system_metadata) + + if full_hash: + existing = get_asset_by_hash(session, full_hash) + if existing and existing.id != asset_id: + reassign_asset_references(session, asset_id, existing.id, reference_id) + delete_orphaned_seed_asset(session, asset_id) + if mime_type: + update_asset_hash_and_mime(session, existing.id, mime_type=mime_type) + else: + update_asset_hash_and_mime(session, asset_id, full_hash, mime_type) + elif mime_type: + update_asset_hash_and_mime(session, asset_id, mime_type=mime_type) + + bulk_update_enrichment_level(session, [reference_id], new_level) + session.commit() + + return new_level + + +def enrich_assets_batch( + rows: list, + extract_metadata: bool = True, + compute_hash: bool = False, + interrupt_check: Callable[[], bool] | None = None, + hash_checkpoints: dict[str, HashCheckpoint] | None = None, +) -> tuple[int, list[str]]: + """Enrich a batch of assets. + + Uses a single DB session for the entire batch, committing after each + individual asset to avoid long-held transactions while eliminating + per-asset session creation overhead. + + Args: + rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots + extract_metadata: If True, extract metadata for each asset + compute_hash: If True, compute hash for each asset + interrupt_check: Optional non-blocking callable that returns True if + the operation should be interrupted (e.g. 
paused or cancelled) + hash_checkpoints: Optional dict for saving/restoring hash progress + across interruptions, keyed by file path + + Returns: + Tuple of (enriched_count, failed_reference_ids) + """ + enriched = 0 + failed_ids: list[str] = [] + + with create_session() as sess: + for row in rows: + if interrupt_check is not None and interrupt_check(): + break + + try: + new_level = enrich_asset( + sess, + file_path=row.file_path, + reference_id=row.reference_id, + asset_id=row.asset_id, + extract_metadata=extract_metadata, + compute_hash=compute_hash, + interrupt_check=interrupt_check, + hash_checkpoints=hash_checkpoints, + ) + if new_level > row.enrichment_level: + enriched += 1 + else: + failed_ids.append(row.reference_id) + except Exception as e: + logging.warning("Failed to enrich %s: %s", row.file_path, e) + sess.rollback() + failed_ids.append(row.reference_id) + + return enriched, failed_ids diff --git a/app/assets/seeder.py b/app/assets/seeder.py new file mode 100644 index 000000000..2262928e5 --- /dev/null +++ b/app/assets/seeder.py @@ -0,0 +1,846 @@ +"""Background asset seeder with thread management and cancellation support.""" + +import logging +import os +import threading +import time +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable + +from app.assets.scanner import ( + ENRICHMENT_METADATA, + ENRICHMENT_STUB, + RootType, + build_asset_specs, + collect_paths_for_roots, + enrich_assets_batch, + get_all_known_prefixes, + get_prefixes_for_root, + get_unenriched_assets_for_roots, + insert_asset_specs, + mark_missing_outside_prefixes_safely, + sync_root_safely, +) +from app.database.db import dependencies_available + + +class ScanInProgressError(Exception): + """Raised when an operation cannot proceed because a scan is running.""" + + +class State(Enum): + """Seeder state machine states.""" + + IDLE = "IDLE" + RUNNING = "RUNNING" + PAUSED = "PAUSED" + CANCELLING = "CANCELLING" + + +class ScanPhase(Enum): + """Scan phase options.""" + + FAST = "fast" # Phase 1: filesystem only (stubs) + ENRICH = "enrich" # Phase 2: metadata + hash + FULL = "full" # Both phases sequentially + + +@dataclass +class Progress: + """Progress information for a scan operation.""" + + scanned: int = 0 + total: int = 0 + created: int = 0 + skipped: int = 0 + + +@dataclass +class ScanStatus: + """Current status of the asset seeder.""" + + state: State + progress: Progress | None + errors: list[str] = field(default_factory=list) + + +ProgressCallback = Callable[[Progress], None] + + +class _AssetSeeder: + """Background asset scanning manager. + + Spawns ephemeral daemon threads for scanning. + Each scan creates a new thread that exits when complete. + Use the module-level ``asset_seeder`` instance. + """ + + def __init__(self) -> None: + # RLock is required because _run_scan() drains pending work while + # holding _lock and re-enters start() which also acquires _lock. + self._lock = threading.RLock() + self._state = State.IDLE + self._progress: Progress | None = None + self._last_progress: Progress | None = None + self._errors: list[str] = [] + self._thread: threading.Thread | None = None + self._cancel_event = threading.Event() + self._run_gate = threading.Event() + self._run_gate.set() # Start unpaused (set = running, clear = paused) + self._roots: tuple[RootType, ...] 
= () + self._phase: ScanPhase = ScanPhase.FULL + self._compute_hashes: bool = False + self._prune_first: bool = False + self._progress_callback: ProgressCallback | None = None + self._disabled: bool = False + self._pending_enrich: dict | None = None + + def disable(self) -> None: + """Disable the asset seeder, preventing any scans from starting.""" + self._disabled = True + logging.info("Asset seeder disabled") + + def is_disabled(self) -> bool: + """Check if the asset seeder is disabled.""" + return self._disabled + + def start( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + phase: ScanPhase = ScanPhase.FULL, + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + compute_hashes: bool = False, + ) -> bool: + """Start a background scan for the given roots. + + Args: + roots: Tuple of root types to scan (models, input, output) + phase: Scan phase to run (FAST, ENRICH, or FULL for both) + progress_callback: Optional callback called with progress updates + prune_first: If True, prune orphaned assets before scanning + compute_hashes: If True, compute blake3 hashes (slow) + + Returns: + True if scan was started, False if already running + """ + if self._disabled: + logging.debug("Asset seeder is disabled, skipping start") + return False + logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value) + with self._lock: + if self._state != State.IDLE: + logging.info("Asset seeder already running, skipping start") + return False + self._state = State.RUNNING + self._progress = Progress() + self._errors = [] + self._roots = roots + self._phase = phase + self._prune_first = prune_first + self._compute_hashes = compute_hashes + self._progress_callback = progress_callback + self._cancel_event.clear() + self._run_gate.set() # Ensure unpaused when starting + self._thread = threading.Thread( + target=self._run_scan, + name="_AssetSeeder", + daemon=True, + ) + self._thread.start() + return True + + def start_fast( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + prune_first: bool = False, + ) -> bool: + """Start a fast scan (phase 1 only) - creates stub records. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + prune_first: If True, prune orphaned assets before scanning + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.FAST, + progress_callback=progress_callback, + prune_first=prune_first, + compute_hashes=False, + ) + + def start_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + progress_callback: ProgressCallback | None = None, + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan (phase 2 only) - extracts metadata and hashes. + + Args: + roots: Tuple of root types to scan + progress_callback: Optional callback for progress updates + compute_hashes: If True, compute blake3 hashes + + Returns: + True if scan was started, False if already running + """ + return self.start( + roots=roots, + phase=ScanPhase.ENRICH, + progress_callback=progress_callback, + prune_first=False, + compute_hashes=compute_hashes, + ) + + def enqueue_enrich( + self, + roots: tuple[RootType, ...] = ("models", "input", "output"), + compute_hashes: bool = False, + ) -> bool: + """Start an enrichment scan now, or queue it for after the current scan. + + If the seeder is idle, starts immediately. 
Otherwise, the enrich + request is stored and will run automatically when the current scan + finishes. + + Args: + roots: Tuple of root types to scan + compute_hashes: If True, compute blake3 hashes + + Returns: + True if started immediately, False if queued for later + """ + with self._lock: + if self.start_enrich(roots=roots, compute_hashes=compute_hashes): + return True + if self._pending_enrich is not None: + existing_roots = set(self._pending_enrich["roots"]) + existing_roots.update(roots) + self._pending_enrich["roots"] = tuple(existing_roots) + self._pending_enrich["compute_hashes"] = ( + self._pending_enrich["compute_hashes"] or compute_hashes + ) + else: + self._pending_enrich = { + "roots": roots, + "compute_hashes": compute_hashes, + } + logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"]) + return False + + def cancel(self) -> bool: + """Request cancellation of the current scan. + + Returns: + True if cancellation was requested, False if not running or paused + """ + with self._lock: + if self._state not in (State.RUNNING, State.PAUSED): + return False + logging.info("Asset seeder cancelling (was %s)", self._state.value) + self._state = State.CANCELLING + self._cancel_event.set() + self._run_gate.set() # Unblock if paused so thread can exit + return True + + def stop(self) -> bool: + """Stop the current scan (alias for cancel). + + Returns: + True if stop was requested, False if not running + """ + return self.cancel() + + def pause(self) -> bool: + """Pause the current scan. + + The scan will complete its current batch before pausing. + + Returns: + True if pause was requested, False if not running + """ + with self._lock: + if self._state != State.RUNNING: + return False + logging.info("Asset seeder pausing") + self._state = State.PAUSED + self._run_gate.clear() + return True + + def resume(self) -> bool: + """Resume a paused scan. + + This is a noop if the scan is not in the PAUSED state + + Returns: + True if resumed, False if not paused + """ + with self._lock: + if self._state != State.PAUSED: + return False + logging.info("Asset seeder resuming") + self._state = State.RUNNING + self._run_gate.set() + self._emit_event("assets.seed.resumed", {}) + return True + + def restart( + self, + roots: tuple[RootType, ...] | None = None, + phase: ScanPhase | None = None, + progress_callback: ProgressCallback | None = None, + prune_first: bool | None = None, + compute_hashes: bool | None = None, + timeout: float = 5.0, + ) -> bool: + """Cancel any running scan and start a new one. 
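+
+        Waits up to ``timeout`` seconds for the previous scan thread
+        to exit; if it is still alive after that, no new scan is
+        started and False is returned.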
+ + Args: + roots: Roots to scan (defaults to previous roots) + phase: Scan phase (defaults to previous phase) + progress_callback: Progress callback (defaults to previous) + prune_first: Prune before scan (defaults to previous) + compute_hashes: Compute hashes (defaults to previous) + timeout: Max seconds to wait for current scan to stop + + Returns: + True if new scan was started, False if failed to stop previous + """ + logging.info("Asset seeder restart requested") + with self._lock: + prev_roots = self._roots + prev_phase = self._phase + prev_callback = self._progress_callback + prev_prune = self._prune_first + prev_hashes = self._compute_hashes + + self.cancel() + if not self.wait(timeout=timeout): + return False + + cb = progress_callback if progress_callback is not None else prev_callback + return self.start( + roots=roots if roots is not None else prev_roots, + phase=phase if phase is not None else prev_phase, + progress_callback=cb, + prune_first=prune_first if prune_first is not None else prev_prune, + compute_hashes=( + compute_hashes if compute_hashes is not None else prev_hashes + ), + ) + + def wait(self, timeout: float | None = None) -> bool: + """Wait for the current scan to complete. + + Args: + timeout: Maximum seconds to wait, or None for no timeout + + Returns: + True if scan completed, False if timeout expired or no scan running + """ + with self._lock: + thread = self._thread + if thread is None: + return True + thread.join(timeout=timeout) + return not thread.is_alive() + + def get_status(self) -> ScanStatus: + """Get the current status and progress of the seeder.""" + with self._lock: + src = self._progress or self._last_progress + return ScanStatus( + state=self._state, + progress=Progress( + scanned=src.scanned, + total=src.total, + created=src.created, + skipped=src.skipped, + ) + if src + else None, + errors=list(self._errors), + ) + + def shutdown(self, timeout: float = 5.0) -> None: + """Gracefully shutdown: cancel any running scan and wait for thread. + + Args: + timeout: Maximum seconds to wait for thread to exit + """ + self.cancel() + self.wait(timeout=timeout) + with self._lock: + self._thread = None + + def mark_missing_outside_prefixes(self) -> int: + """Mark references as missing when outside all known root prefixes. + + This is a non-destructive soft-delete operation. Assets and their + metadata are preserved, but references are flagged as missing. + They can be restored if the file reappears in a future scan. + + This operation is decoupled from scanning to prevent partial scans + from accidentally marking assets belonging to other roots. + + Should be called explicitly when cleanup is desired, typically after + a full scan of all roots or during maintenance. + + Returns: + Number of references marked as missing + + Raises: + ScanInProgressError: If a scan is currently running + """ + with self._lock: + if self._state != State.IDLE: + raise ScanInProgressError( + "Cannot mark missing assets while scan is running" + ) + self._state = State.RUNNING + + try: + if not dependencies_available(): + logging.warning( + "Database dependencies not available, skipping mark missing" + ) + return 0 + + all_prefixes = get_all_known_prefixes() + marked = mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d references as missing", marked) + return marked + finally: + with self._lock: + self._reset_to_idle() + + def _reset_to_idle(self) -> None: + """Reset state to IDLE, preserving last progress. 
Caller must hold _lock.""" + self._last_progress = self._progress + self._state = State.IDLE + self._progress = None + + def _is_cancelled(self) -> bool: + """Check if cancellation has been requested.""" + return self._cancel_event.is_set() + + def _is_paused_or_cancelled(self) -> bool: + """Non-blocking check: True if paused or cancelled. + + Use as interrupt_check for I/O-bound work (e.g. hashing) so that + file handles are released immediately on pause rather than held + open while blocked. The caller is responsible for blocking on + _check_pause_and_cancel() afterward. + """ + return not self._run_gate.is_set() or self._cancel_event.is_set() + + def _check_pause_and_cancel(self) -> bool: + """Block while paused, then check if cancelled. + + Call this at checkpoint locations in scan loops. It will: + 1. Block indefinitely while paused (until resume or cancel) + 2. Return True if cancelled, False to continue + + Returns: + True if scan should stop, False to continue + """ + if not self._run_gate.is_set(): + self._emit_event("assets.seed.paused", {}) + self._run_gate.wait() # Blocks if paused + return self._is_cancelled() + + def _emit_event(self, event_type: str, data: dict) -> None: + """Emit a WebSocket event if server is available.""" + try: + from server import PromptServer + + if hasattr(PromptServer, "instance") and PromptServer.instance: + PromptServer.instance.send_sync(event_type, data) + except Exception: + pass + + def _update_progress( + self, + scanned: int | None = None, + total: int | None = None, + created: int | None = None, + skipped: int | None = None, + ) -> None: + """Update progress counters (thread-safe).""" + callback: ProgressCallback | None = None + progress: Progress | None = None + + with self._lock: + if self._progress is None: + return + if scanned is not None: + self._progress.scanned = scanned + if total is not None: + self._progress.total = total + if created is not None: + self._progress.created = created + if skipped is not None: + self._progress.skipped = skipped + if self._progress_callback: + callback = self._progress_callback + progress = Progress( + scanned=self._progress.scanned, + total=self._progress.total, + created=self._progress.created, + skipped=self._progress.skipped, + ) + + if callback and progress: + try: + callback(progress) + except Exception: + pass + + _MAX_ERRORS = 200 + + def _add_error(self, message: str) -> None: + """Add an error message (thread-safe), capped at _MAX_ERRORS.""" + with self._lock: + if len(self._errors) < self._MAX_ERRORS: + self._errors.append(message) + + def _log_scan_config(self, roots: tuple[RootType, ...]) -> None: + """Log the directories that will be scanned.""" + import folder_paths + + for root in roots: + if root == "models": + logging.info( + "Asset scan [models] directory: %s", + os.path.abspath(folder_paths.models_dir), + ) + else: + prefixes = get_prefixes_for_root(root) + if prefixes: + logging.info("Asset scan [%s] directories: %s", root, prefixes) + + def _run_scan(self) -> None: + """Main scan loop running in background thread.""" + t_start = time.perf_counter() + roots = self._roots + phase = self._phase + cancelled = False + total_created = 0 + total_enriched = 0 + skipped_existing = 0 + total_paths = 0 + + try: + if not dependencies_available(): + self._add_error("Database dependencies not available") + self._emit_event( + "assets.seed.error", + {"message": "Database dependencies not available"}, + ) + return + + if self._prune_first: + all_prefixes = get_all_known_prefixes() + marked = 
mark_missing_outside_prefixes_safely(all_prefixes) + if marked > 0: + logging.info("Marked %d refs as missing before scan", marked) + + if self._check_pause_and_cancel(): + logging.info("Asset scan cancelled after pruning phase") + cancelled = True + return + + self._log_scan_config(roots) + + # Phase 1: Fast scan (stub records) + if phase in (ScanPhase.FAST, ScanPhase.FULL): + created, skipped, paths = self._run_fast_phase(roots) + total_created, skipped_existing, total_paths = created, skipped, paths + + if self._check_pause_and_cancel(): + cancelled = True + return + + self._emit_event( + "assets.seed.fast_complete", + { + "roots": list(roots), + "created": total_created, + "skipped": skipped_existing, + "total": total_paths, + }, + ) + + # Phase 2: Enrichment scan (metadata + hashes) + if phase in (ScanPhase.ENRICH, ScanPhase.FULL): + if self._check_pause_and_cancel(): + cancelled = True + return + + enrich_cancelled, total_enriched = self._run_enrich_phase(roots) + + if enrich_cancelled: + cancelled = True + return + + self._emit_event( + "assets.seed.enrich_complete", + { + "roots": list(roots), + "enriched": total_enriched, + }, + ) + + elapsed = time.perf_counter() - t_start + logging.info( + "Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d", + roots, + phase.value, + elapsed, + total_created, + total_enriched, + skipped_existing, + ) + + self._emit_event( + "assets.seed.completed", + { + "phase": phase.value, + "total": total_paths, + "created": total_created, + "enriched": total_enriched, + "skipped": skipped_existing, + "elapsed": round(elapsed, 3), + }, + ) + + except Exception as e: + self._add_error(f"Scan failed: {e}") + logging.exception("Asset scan failed") + self._emit_event("assets.seed.error", {"message": str(e)}) + finally: + if cancelled: + self._emit_event( + "assets.seed.cancelled", + { + "scanned": self._progress.scanned if self._progress else 0, + "total": total_paths, + "created": total_created, + }, + ) + with self._lock: + self._reset_to_idle() + pending = self._pending_enrich + if pending is not None: + self._pending_enrich = None + if not self.start_enrich( + roots=pending["roots"], + compute_hashes=pending["compute_hashes"], + ): + logging.warning( + "Pending enrich scan could not start (roots=%s)", + pending["roots"], + ) + + def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]: + """Run phase 1: fast scan to create stub records. 
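+
+        Syncs each root, collects candidate file paths, builds stub
+        specs (no metadata extraction, no hashing), and bulk-inserts
+        them in batches of 500, checking for pause/cancellation
+        between batches.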
+ + Returns: + Tuple of (total_created, skipped_existing, total_paths) + """ + t_fast_start = time.perf_counter() + total_created = 0 + skipped_existing = 0 + + existing_paths: set[str] = set() + t_sync = time.perf_counter() + for r in roots: + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + existing_paths.update(sync_root_safely(r)) + logging.debug( + "Fast scan: sync_root phase took %.3fs (%d existing paths)", + time.perf_counter() - t_sync, + len(existing_paths), + ) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, 0 + + t_collect = time.perf_counter() + paths = collect_paths_for_roots(roots) + logging.debug( + "Fast scan: collect_paths took %.3fs (%d paths found)", + time.perf_counter() - t_collect, + len(paths), + ) + total_paths = len(paths) + self._update_progress(total=total_paths) + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "total": total_paths, "phase": "fast"}, + ) + + # Use stub specs (no metadata extraction, no hashing) + t_specs = time.perf_counter() + specs, tag_pool, skipped_existing = build_asset_specs( + paths, + existing_paths, + enable_metadata_extraction=False, + compute_hashes=False, + ) + logging.debug( + "Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)", + time.perf_counter() - t_specs, + len(specs), + skipped_existing, + ) + self._update_progress(skipped=skipped_existing) + + if self._check_pause_and_cancel(): + return total_created, skipped_existing, total_paths + + batch_size = 500 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + for i in range(0, len(specs), batch_size): + if self._check_pause_and_cancel(): + logging.info( + "Fast scan cancelled after %d/%d files (created=%d)", + i, + len(specs), + total_created, + ) + return total_created, skipped_existing, total_paths + + batch = specs[i : i + batch_size] + batch_tags = {t for spec in batch for t in spec["tags"]} + try: + created = insert_asset_specs(batch, batch_tags) + total_created += created + except Exception as e: + self._add_error(f"Batch insert failed at offset {i}: {e}") + logging.exception("Batch insert failed at offset %d", i) + + scanned = i + len(batch) + now = time.perf_counter() + self._update_progress(scanned=scanned, created=total_created) + + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "fast", + "scanned": scanned, + "total": len(specs), + "created": total_created, + }, + ) + last_progress_time = now + + self._update_progress(scanned=len(specs), created=total_created) + logging.info( + "Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)", + time.perf_counter() - t_fast_start, + total_created, + skipped_existing, + total_paths, + ) + return total_created, skipped_existing, total_paths + + def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]: + """Run phase 2: enrich existing records with metadata and hashes. 
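+
+        Repeatedly fetches batches of unenriched assets, skipping
+        references that already failed, and stops after three
+        consecutive batches with no progress. Hash checkpoints are
+        carried across batches so an interrupted hash resumes where
+        it left off instead of re-reading the whole file.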
+ + Returns: + Tuple of (cancelled, total_enriched) + """ + total_enriched = 0 + batch_size = 100 + last_progress_time = time.perf_counter() + progress_interval = 1.0 + + # Get the target enrichment level based on compute_hashes + if not self._compute_hashes: + target_max_level = ENRICHMENT_STUB + else: + target_max_level = ENRICHMENT_METADATA + + self._emit_event( + "assets.seed.started", + {"roots": list(roots), "phase": "enrich"}, + ) + + skip_ids: set[str] = set() + consecutive_empty = 0 + max_consecutive_empty = 3 + + # Hash checkpoints survive across batches so interrupted hashes + # can be resumed without re-reading the entire file. + hash_checkpoints: dict[str, object] = {} + + while True: + if self._check_pause_and_cancel(): + logging.info("Enrich scan cancelled after %d assets", total_enriched) + return True, total_enriched + + # Fetch next batch of unenriched assets + unenriched = get_unenriched_assets_for_roots( + roots, + max_level=target_max_level, + limit=batch_size, + ) + + # Filter out previously failed references + if skip_ids: + unenriched = [r for r in unenriched if r.reference_id not in skip_ids] + + if not unenriched: + break + + enriched, failed_ids = enrich_assets_batch( + unenriched, + extract_metadata=True, + compute_hash=self._compute_hashes, + interrupt_check=self._is_paused_or_cancelled, + hash_checkpoints=hash_checkpoints, + ) + total_enriched += enriched + skip_ids.update(failed_ids) + + if enriched == 0: + consecutive_empty += 1 + if consecutive_empty >= max_consecutive_empty: + logging.warning( + "Enrich phase stopping: %d consecutive batches with no progress (%d skipped)", + consecutive_empty, + len(skip_ids), + ) + break + else: + consecutive_empty = 0 + + now = time.perf_counter() + if now - last_progress_time >= progress_interval: + self._emit_event( + "assets.seed.progress", + { + "phase": "enrich", + "enriched": total_enriched, + }, + ) + last_progress_time = now + + return False, total_enriched + + +asset_seeder = _AssetSeeder() diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py new file mode 100644 index 000000000..03990966b --- /dev/null +++ b/app/assets/services/__init__.py @@ -0,0 +1,91 @@ +from app.assets.services.asset_management import ( + asset_exists, + delete_asset_reference, + get_asset_by_hash, + get_asset_detail, + list_assets_page, + resolve_asset_for_download, + set_asset_preview, + update_asset_metadata, +) +from app.assets.services.bulk_ingest import ( + BulkInsertResult, + batch_insert_seed_assets, + cleanup_unreferenced_assets, +) +from app.assets.services.file_utils import ( + get_mtime_ns, + get_size_and_mtime_ns, + list_files_recursively, + verify_file_unchanged, +) +from app.assets.services.ingest import ( + DependencyMissingError, + HashMismatchError, + create_from_hash, + ingest_existing_file, + register_output_files, + upload_from_temp_path, +) +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, +) +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + IngestResult, + ListAssetsResult, + ReferenceData, + RegisterAssetResult, + TagUsage, + UploadResult, + UserMetadata, +) +from app.assets.services.tagging import ( + apply_tags, + list_tags, + remove_tags, +) + +__all__ = [ + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetSummaryData", + "ReferenceData", + "BulkInsertResult", + "DependencyMissingError", + "DownloadResolutionResult", + "HashMismatchError", + "IngestResult", + 
"ListAssetsResult", + "RegisterAssetResult", + "RemoveTagsResult", + "TagUsage", + "UploadResult", + "UserMetadata", + "apply_tags", + "asset_exists", + "batch_insert_seed_assets", + "create_from_hash", + "delete_asset_reference", + "get_asset_by_hash", + "get_asset_detail", + "ingest_existing_file", + "register_output_files", + "get_mtime_ns", + "get_size_and_mtime_ns", + "list_assets_page", + "list_files_recursively", + "list_tags", + "cleanup_unreferenced_assets", + "remove_tags", + "resolve_asset_for_download", + "set_asset_preview", + "update_asset_metadata", + "upload_from_temp_path", + "verify_file_unchanged", +] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py new file mode 100644 index 000000000..5aefd9956 --- /dev/null +++ b/app/assets/services/asset_management.py @@ -0,0 +1,367 @@ +import contextlib +import mimetypes +import os +from typing import Sequence + + +from app.assets.database.models import Asset +from app.assets.database.queries import ( + asset_exists_by_hash, + reference_exists_for_asset_id, + delete_reference_by_id, + fetch_reference_and_asset, + soft_delete_reference_by_id, + fetch_reference_asset_and_tags, + get_asset_by_hash as queries_get_asset_by_hash, + get_reference_by_id, + get_reference_with_owner_check, + list_references_page, + list_all_file_paths_by_asset_id, + list_references_by_asset_id, + set_reference_metadata, + set_reference_preview, + set_reference_tags, + update_asset_hash_and_mime, + update_reference_access_time, + update_reference_name, + update_reference_updated_at, +) +from app.assets.helpers import select_best_live_path +from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.schemas import ( + AssetData, + AssetDetailResult, + AssetSummaryData, + DownloadResolutionResult, + ListAssetsResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def get_asset_detail( + reference_id: str, + owner_id: str = "", +) -> AssetDetailResult | None: + with create_session() as session: + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + return None + + ref, asset, tags = result + return AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + + +def update_asset_metadata( + reference_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: UserMetadata = None, + tag_origin: str = "manual", + owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, +) -> AssetDetailResult: + with create_session() as session: + ref = get_reference_with_owner_check(session, reference_id, owner_id) + + touched = False + if name is not None and name != ref.name: + update_reference_name(session, reference_id=reference_id, name=name) + touched = True + + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + + new_meta: dict | None = None + if user_metadata is not None: + new_meta = dict(user_metadata) + elif computed_filename: + current_meta = ref.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + + if new_meta is not None: + if computed_filename: + new_meta["filename"] = computed_filename + set_reference_metadata( + session, reference_id=reference_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_reference_tags( + session, + 
reference_id=reference_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if mime_type is not None: + updated = update_asset_hash_and_mime( + session, asset_id=ref.asset_id, mime_type=mime_type + ) + if updated: + touched = True + + if preview_id is not None: + set_reference_preview( + session, + reference_id=reference_id, + preview_reference_id=preview_id, + ) + touched = True + + if touched and user_metadata is None: + update_reference_updated_at(session, reference_id=reference_id) + + result = fetch_reference_asset_and_tags( + session, + reference_id=reference_id, + owner_id=owner_id, + ) + if not result: + raise RuntimeError("State changed during update") + + ref, asset, tag_list = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_list, + ) + session.commit() + + return detail + + +def delete_asset_reference( + reference_id: str, + owner_id: str, + delete_content_if_orphan: bool = True, +) -> bool: + with create_session() as session: + if not delete_content_if_orphan: + # Soft delete: mark the reference as deleted but keep everything + deleted = soft_delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + session.commit() + return deleted + + ref_row = get_reference_by_id(session, reference_id=reference_id) + asset_id = ref_row.asset_id if ref_row else None + file_path = ref_row.file_path if ref_row else None + + deleted = delete_reference_by_id( + session, reference_id=reference_id, owner_id=owner_id + ) + if not deleted: + session.commit() + return False + + if not asset_id: + session.commit() + return True + + still_exists = reference_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + # Orphaned asset - gather ALL file paths (including + # soft-deleted / missing refs) so their on-disk files get cleaned up. 
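+        # Files are removed only after the commit below succeeds, so a
+        # failed transaction never deletes anything from disk.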
+ file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id) + # Also include the just-deleted file path + if file_path: + file_paths.append(file_path) + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + + # Delete files after commit + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + + return True + + +def set_asset_preview( + reference_id: str, + preview_reference_id: str | None = None, + owner_id: str = "", +) -> AssetDetailResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + set_reference_preview( + session, + reference_id=reference_id, + preview_reference_id=preview_reference_id, + ) + + result = fetch_reference_asset_and_tags( + session, reference_id=reference_id, owner_id=owner_id + ) + if not result: + raise RuntimeError("State changed during preview update") + + ref, asset, tags = result + detail = AssetDetailResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tags, + ) + session.commit() + + return detail + + +def asset_exists(asset_hash: str) -> bool: + with create_session() as session: + return asset_exists_by_hash(session, asset_hash=asset_hash) + + +def get_asset_by_hash(asset_hash: str) -> AssetData | None: + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash=asset_hash) + return extract_asset_data(asset) + + +def list_assets_page( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 20, + offset: int = 0, + sort: str = "created_at", + order: str = "desc", +) -> ListAssetsResult: + with create_session() as session: + refs, tag_map, total = list_references_page( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + offset=offset, + sort=sort, + order=order, + ) + + items: list[AssetSummaryData] = [] + for ref in refs: + items.append( + AssetSummaryData( + ref=extract_reference_data(ref), + asset=extract_asset_data(ref.asset), + tags=tag_map.get(ref.id, []), + ) + ) + + return ListAssetsResult(items=items, total=total) + + +def resolve_hash_to_path( + asset_hash: str, + owner_id: str = "", +) -> DownloadResolutionResult | None: + """Resolve a blake3 hash to an on-disk file path. + + Only references visible to *owner_id* are considered (owner-less + references are always visible). + + Returns a DownloadResolutionResult with abs_path, content_type, and + download_name, or None if no asset or live path is found. 
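+
+    Example (illustrative; ``serve`` stands in for any response
+    helper)::
+
+        res = resolve_hash_to_path("blake3:<hex digest>")
+        if res is not None:
+            serve(res.abs_path, res.content_type, res.download_name)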
+ """ + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash) + if not asset: + return None + refs = list_references_by_asset_id(session, asset_id=asset.id) + visible = [ + r for r in refs + if r.owner_id == "" or r.owner_id == owner_id + ] + abs_path = select_best_live_path(visible) + if not abs_path: + return None + display_name = os.path.basename(abs_path) + for ref in visible: + if ref.file_path == abs_path and ref.name: + display_name = ref.name + break + ctype = ( + asset.mime_type + or mimetypes.guess_type(display_name)[0] + or "application/octet-stream" + ) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=display_name, + ) + + +def resolve_asset_for_download( + reference_id: str, + owner_id: str = "", +) -> DownloadResolutionResult: + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise ValueError(f"AssetReference {reference_id} not found") + + ref, asset = pair + + # For references with file_path, use that directly + if ref.file_path and os.path.isfile(ref.file_path): + abs_path = ref.file_path + else: + # For API-created refs without file_path, find a path from other refs + refs = list_references_by_asset_id(session, asset_id=asset.id) + abs_path = select_best_live_path(refs) + if not abs_path: + raise FileNotFoundError( + f"No live path for AssetReference {reference_id} " + f"(asset id={asset.id}, name={ref.name})" + ) + + # Capture ORM attributes before commit (commit expires loaded objects) + ref_name = ref.name + asset_mime = asset.mime_type + + update_reference_access_time(session, reference_id=reference_id) + session.commit() + + ctype = ( + asset_mime + or mimetypes.guess_type(ref_name or abs_path)[0] + or "application/octet-stream" + ) + download_name = ref_name or os.path.basename(abs_path) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=download_name, + ) diff --git a/app/assets/services/bulk_ingest.py b/app/assets/services/bulk_ingest.py new file mode 100644 index 000000000..67aad838f --- /dev/null +++ b/app/assets/services/bulk_ingest.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import os +import uuid +from dataclasses import dataclass +from datetime import datetime +from typing import TYPE_CHECKING, Any, TypedDict + +from sqlalchemy.orm import Session + +from app.assets.database.queries import ( + bulk_insert_assets, + bulk_insert_references_ignore_conflicts, + bulk_insert_tags_and_meta, + delete_assets_by_ids, + get_existing_asset_ids, + get_reference_ids_by_ids, + get_references_by_paths_and_asset_ids, + get_unreferenced_unhashed_asset_ids, + restore_references_by_paths, +) +from app.assets.helpers import get_utc_now + +if TYPE_CHECKING: + from app.assets.services.metadata_extract import ExtractedMetadata + + +class SeedAssetSpec(TypedDict): + """Spec for seeding an asset from filesystem.""" + + abs_path: str + size_bytes: int + mtime_ns: int + info_name: str + tags: list[str] + fname: str + metadata: ExtractedMetadata | None + hash: str | None + mime_type: str | None + job_id: str | None + + +class AssetRow(TypedDict): + """Row data for inserting an Asset.""" + + id: str + hash: str | None + size_bytes: int + mime_type: str | None + created_at: datetime + + +class ReferenceRow(TypedDict): + """Row data for inserting an AssetReference.""" + + id: str + asset_id: str + file_path: str + mtime_ns: int + owner_id: str + name: str 
+ preview_id: str | None + user_metadata: dict[str, Any] | None + job_id: str | None + created_at: datetime + updated_at: datetime + last_access_time: datetime + + +class TagRow(TypedDict): + """Row data for inserting a Tag.""" + + asset_reference_id: str + tag_name: str + origin: str + added_at: datetime + + +class MetadataRow(TypedDict): + """Row data for inserting asset metadata.""" + + asset_reference_id: str + key: str + ordinal: int + val_str: str | None + val_num: float | None + val_bool: bool | None + val_json: dict[str, Any] | None + + +@dataclass +class BulkInsertResult: + """Result of bulk asset insertion.""" + + inserted_refs: int + won_paths: int + lost_paths: int + + +def batch_insert_seed_assets( + session: Session, + specs: list[SeedAssetSpec], + owner_id: str = "", +) -> BulkInsertResult: + """Seed assets from filesystem specs in batch. + + Each spec is a dict with keys: + - abs_path: str + - size_bytes: int + - mtime_ns: int + - info_name: str + - tags: list[str] + - fname: Optional[str] + + This function orchestrates: + 1. Insert seed Assets (hash=NULL) + 2. Claim references with ON CONFLICT DO NOTHING on file_path + 3. Query to find winners (paths where our asset_id was inserted) + 4. Delete Assets for losers (path already claimed by another asset) + 5. Insert tags and metadata for successfully inserted references + + Returns: + BulkInsertResult with inserted_refs, won_paths, lost_paths + """ + if not specs: + return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0) + + current_time = get_utc_now() + asset_rows: list[AssetRow] = [] + reference_rows: list[ReferenceRow] = [] + path_to_asset_id: dict[str, str] = {} + asset_id_to_ref_data: dict[str, dict] = {} + absolute_path_list: list[str] = [] + + for spec in specs: + absolute_path = os.path.abspath(spec["abs_path"]) + asset_id = str(uuid.uuid4()) + reference_id = str(uuid.uuid4()) + absolute_path_list.append(absolute_path) + path_to_asset_id[absolute_path] = asset_id + + mime_type = spec.get("mime_type") + asset_rows.append( + { + "id": asset_id, + "hash": spec.get("hash"), + "size_bytes": spec["size_bytes"], + "mime_type": mime_type, + "created_at": current_time, + } + ) + + # Build user_metadata from extracted metadata or fallback to filename + extracted_metadata = spec.get("metadata") + if extracted_metadata: + user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata() + elif spec["fname"]: + user_metadata = {"filename": spec["fname"]} + else: + user_metadata = None + + reference_rows.append( + { + "id": reference_id, + "asset_id": asset_id, + "file_path": absolute_path, + "mtime_ns": spec["mtime_ns"], + "owner_id": owner_id, + "name": spec["info_name"], + "preview_id": None, + "user_metadata": user_metadata, + "job_id": spec.get("job_id"), + "created_at": current_time, + "updated_at": current_time, + "last_access_time": current_time, + } + ) + + asset_id_to_ref_data[asset_id] = { + "reference_id": reference_id, + "tags": spec["tags"], + "filename": spec["fname"], + "extracted_metadata": extracted_metadata, + } + + bulk_insert_assets(session, asset_rows) + + # Filter reference rows to only those whose assets were actually inserted + # (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING) + inserted_asset_ids = get_existing_asset_ids( + session, [r["asset_id"] for r in reference_rows] + ) + reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids] + + bulk_insert_references_ignore_conflicts(session, reference_rows) + 
restore_references_by_paths(session, absolute_path_list) + winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id) + + inserted_paths = { + path + for path in absolute_path_list + if path_to_asset_id[path] in inserted_asset_ids + } + losing_paths = inserted_paths - winning_paths + lost_asset_ids = [path_to_asset_id[path] for path in losing_paths] + + if lost_asset_ids: + delete_assets_by_ids(session, lost_asset_ids) + + if not winning_paths: + return BulkInsertResult( + inserted_refs=0, + won_paths=0, + lost_paths=len(losing_paths), + ) + + # Get reference IDs for winners + winning_ref_ids = [ + asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"] + for path in winning_paths + ] + inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids) + + tag_rows: list[TagRow] = [] + metadata_rows: list[MetadataRow] = [] + + if inserted_ref_ids: + for path in winning_paths: + asset_id = path_to_asset_id[path] + ref_data = asset_id_to_ref_data[asset_id] + ref_id = ref_data["reference_id"] + + if ref_id not in inserted_ref_ids: + continue + + for tag in ref_data["tags"]: + tag_rows.append( + { + "asset_reference_id": ref_id, + "tag_name": tag, + "origin": "automatic", + "added_at": current_time, + } + ) + + # Use extracted metadata for meta rows if available + extracted_metadata = ref_data.get("extracted_metadata") + if extracted_metadata: + metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id)) + elif ref_data["filename"]: + # Fallback: just store filename + metadata_rows.append( + { + "asset_reference_id": ref_id, + "key": "filename", + "ordinal": 0, + "val_str": ref_data["filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) + + bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows) + + return BulkInsertResult( + inserted_refs=len(inserted_ref_ids), + won_paths=len(winning_paths), + lost_paths=len(losing_paths), + ) + + +def cleanup_unreferenced_assets(session: Session) -> int: + """Hard-delete unhashed assets with no active references. + + This is a destructive operation intended for explicit cleanup. + Only deletes assets where hash=None and all references are missing. + + Returns: + Number of assets deleted + """ + unreferenced_ids = get_unreferenced_unhashed_asset_ids(session) + return delete_assets_by_ids(session, unreferenced_ids) diff --git a/app/assets/services/file_utils.py b/app/assets/services/file_utils.py new file mode 100644 index 000000000..c47ebe460 --- /dev/null +++ b/app/assets/services/file_utils.py @@ -0,0 +1,70 @@ +import os + + +def get_mtime_ns(stat_result: os.stat_result) -> int: + """Extract mtime in nanoseconds from a stat result.""" + return getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) + + +def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]: + """Get file size in bytes and mtime in nanoseconds.""" + st = os.stat(path, follow_symlinks=follow_symlinks) + return st.st_size, get_mtime_ns(st) + + +def verify_file_unchanged( + mtime_db: int | None, + size_db: int | None, + stat_result: os.stat_result, +) -> bool: + """Check if a file is unchanged based on mtime and size. + + Returns True if the file's mtime and size match the database values. + Returns False if mtime_db is None or values don't match. + + size_db=None means don't check size; 0 is a valid recorded size. 
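+
+    Example (illustrative; ``row`` stands for a stored DB record)::
+
+        st = os.stat(path)
+        if not verify_file_unchanged(row.mtime_ns, row.size_bytes, st):
+            ...  # file changed on disk, re-enrich it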
+ """ + if mtime_db is None: + return False + actual_mtime_ns = get_mtime_ns(stat_result) + if int(mtime_db) != int(actual_mtime_ns): + return False + if size_db is not None: + return int(stat_result.st_size) == int(size_db) + return True + + +def is_visible(name: str) -> bool: + """Return True if a file or directory name is visible (not hidden).""" + return not name.startswith(".") + + +def list_files_recursively(base_dir: str) -> list[str]: + """Recursively list all files in a directory, following symlinks.""" + out: list[str] = [] + base_abs = os.path.abspath(base_dir) + if not os.path.isdir(base_abs): + return out + # Track seen real directory identities to prevent circular symlink loops + seen_dirs: set[tuple[int, int]] = set() + for dirpath, subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=True + ): + try: + st = os.stat(dirpath) + dir_id = (st.st_dev, st.st_ino) + except OSError: + subdirs.clear() + continue + if dir_id in seen_dirs: + subdirs.clear() + continue + seen_dirs.add(dir_id) + subdirs[:] = [d for d in subdirs if is_visible(d)] + for name in filenames: + if not is_visible(name): + continue + out.append(os.path.abspath(os.path.join(dirpath, name))) + return out diff --git a/app/assets/services/hashing.py b/app/assets/services/hashing.py new file mode 100644 index 000000000..41d8b4615 --- /dev/null +++ b/app/assets/services/hashing.py @@ -0,0 +1,99 @@ +import io +import os +from contextlib import contextmanager +from dataclasses import dataclass +from typing import IO, Any, Callable, Iterator +import logging + +try: + from blake3 import blake3 +except ModuleNotFoundError: + logging.warning("WARNING: blake3 package not installed") + +DEFAULT_CHUNK = 8 * 1024 * 1024 + +InterruptCheck = Callable[[], bool] + + +@dataclass +class HashCheckpoint: + """Saved state for resuming an interrupted hash computation.""" + + bytes_processed: int + hasher: Any # blake3 hasher instance + mtime_ns: int = 0 + file_size: int = 0 + + +@contextmanager +def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]: + """Yield (file_object, is_path) with appropriate setup/teardown.""" + if hasattr(fp, "read"): + seekable = getattr(fp, "seekable", lambda: False)() + orig_pos = None + if seekable: + try: + orig_pos = fp.tell() + if orig_pos != 0: + fp.seek(0) + except io.UnsupportedOperation: + orig_pos = None + try: + yield fp, False + finally: + if orig_pos is not None: + fp.seek(orig_pos) + else: + with open(os.fspath(fp), "rb") as f: + yield f, True + + +def compute_blake3_hash( + fp: str | IO[bytes], + chunk_size: int = DEFAULT_CHUNK, + interrupt_check: InterruptCheck | None = None, + checkpoint: HashCheckpoint | None = None, +) -> tuple[str | None, HashCheckpoint | None]: + """Compute BLAKE3 hash of a file, with optional checkpoint support. + + Args: + fp: File path or file-like object + chunk_size: Size of chunks to read at a time + interrupt_check: Optional callable that returns True if the operation + should be interrupted (e.g. paused or cancelled). Must be + non-blocking so file handles are released immediately. Checked + between chunk reads. 
+ checkpoint: Optional checkpoint to resume from (file paths only) + + Returns: + Tuple of (hex_digest, None) on completion, or + (None, checkpoint) on interruption (file paths only), or + (None, None) on interruption of a file object + """ + if chunk_size <= 0: + chunk_size = DEFAULT_CHUNK + + with _open_for_hashing(fp) as (f, is_path): + if checkpoint is not None and is_path: + f.seek(checkpoint.bytes_processed) + h = checkpoint.hasher + bytes_processed = checkpoint.bytes_processed + else: + h = blake3() + bytes_processed = 0 + + while True: + if interrupt_check is not None and interrupt_check(): + if is_path: + return None, HashCheckpoint( + bytes_processed=bytes_processed, + hasher=h, + ) + return None, None + chunk = f.read(chunk_size) + if not chunk: + break + h.update(chunk) + bytes_processed += len(chunk) + + return h.hexdigest(), None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py new file mode 100644 index 000000000..f0b070517 --- /dev/null +++ b/app/assets/services/ingest.py @@ -0,0 +1,563 @@ +import contextlib +import logging +import mimetypes +import os +from typing import Any, Sequence + +from sqlalchemy.orm import Session + +import app.assets.services.hashing as hashing +from app.assets.database.queries import ( + add_tags_to_reference, + count_active_siblings, + create_stub_asset, + ensure_tags_exist, + fetch_reference_and_asset, + get_asset_by_hash, + get_reference_by_file_path, + get_reference_tags, + get_or_create_reference, + reference_exists, + remove_missing_tag_for_asset_id, + set_reference_metadata, + set_reference_tags, + update_asset_hash_and_mime, + upsert_asset, + upsert_reference, + validate_tags_exist, +) +from app.assets.helpers import get_utc_now, normalize_tags +from app.assets.services.bulk_ingest import batch_insert_seed_assets +from app.assets.services.file_utils import get_size_and_mtime_ns +from app.assets.services.path_utils import ( + compute_relative_filename, + get_name_and_tags_from_asset_path, + resolve_destination_from_tags, + validate_path_within_base, +) +from app.assets.services.schemas import ( + IngestResult, + RegisterAssetResult, + UploadResult, + UserMetadata, + extract_asset_data, + extract_reference_data, +) +from app.database.db import create_session + + +def _ingest_file_from_path( + abs_path: str, + asset_hash: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: UserMetadata = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> IngestResult: + locator = os.path.abspath(abs_path) + user_metadata = user_metadata or {} + + asset_created = False + asset_updated = False + ref_created = False + ref_updated = False + reference_id: str | None = None + + with create_session() as session: + if preview_id: + if not reference_exists(session, preview_id): + preview_id = None + + asset, asset_created, asset_updated = upsert_asset( + session, + asset_hash=asset_hash, + size_bytes=size_bytes, + mime_type=mime_type, + ) + + ref_created, ref_updated = upsert_reference( + session, + asset_id=asset.id, + file_path=locator, + name=info_name or os.path.basename(locator), + mtime_ns=mtime_ns, + owner_id=owner_id, + ) + + # Get the reference we just created/updated + ref = get_reference_by_file_path(session, locator) + if ref: + reference_id = ref.id + + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + norm = 
normalize_tags(list(tags)) + if norm: + if require_existing_tags: + validate_tags_exist(session, norm) + add_tags_to_reference( + session, + reference_id=reference_id, + tags=norm, + origin=tag_origin, + create_if_missing=not require_existing_tags, + ) + + _update_metadata_with_filename( + session, + reference_id=reference_id, + file_path=ref.file_path, + current_metadata=ref.user_metadata, + user_metadata=user_metadata, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + + session.commit() + + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + ref_created=ref_created, + ref_updated=ref_updated, + reference_id=reference_id, + ) + + +def register_output_files( + file_paths: Sequence[str], + user_metadata: UserMetadata = None, + job_id: str | None = None, +) -> int: + """Register a batch of output file paths as assets. + + Returns the number of files successfully registered. + """ + registered = 0 + for abs_path in file_paths: + if not os.path.isfile(abs_path): + continue + try: + if ingest_existing_file( + abs_path, user_metadata=user_metadata, job_id=job_id + ): + registered += 1 + except Exception: + logging.exception("Failed to register output: %s", abs_path) + return registered + + +def ingest_existing_file( + abs_path: str, + user_metadata: UserMetadata = None, + extra_tags: Sequence[str] = (), + owner_id: str = "", + job_id: str | None = None, +) -> bool: + """Register an existing on-disk file as an asset stub. + + If a reference already exists for this path, updates mtime_ns, job_id, + size_bytes, and resets enrichment so the enricher will re-hash it. + + For brand-new paths, inserts a stub record (hash=NULL) for immediate + UX visibility. + + Returns True if a row was inserted or updated, False otherwise. + """ + locator = os.path.abspath(abs_path) + size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) + mime_type = mimetypes.guess_type(abs_path, strict=False)[0] + name, path_tags = get_name_and_tags_from_asset_path(abs_path) + tags = list(dict.fromkeys(path_tags + list(extra_tags))) + + with create_session() as session: + existing_ref = get_reference_by_file_path(session, locator) + if existing_ref is not None: + now = get_utc_now() + existing_ref.mtime_ns = mtime_ns + existing_ref.job_id = job_id + existing_ref.is_missing = False + existing_ref.deleted_at = None + existing_ref.updated_at = now + existing_ref.enrichment_level = 0 + + asset = existing_ref.asset + if asset: + # If other refs share this asset, detach to a new stub + # instead of mutating the shared row. 
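+                # The stub keeps hash=None (and enrichment_level was
+                # reset above), so the enricher re-hashes this file
+                # rather than trusting the shared asset's hash.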
+ siblings = count_active_siblings(session, asset.id, existing_ref.id) + if siblings > 0: + new_asset = create_stub_asset( + session, + size_bytes=size_bytes, + mime_type=mime_type or asset.mime_type, + ) + existing_ref.asset_id = new_asset.id + else: + asset.hash = None + asset.size_bytes = size_bytes + if mime_type: + asset.mime_type = mime_type + session.commit() + return True + + spec = { + "abs_path": abs_path, + "size_bytes": size_bytes, + "mtime_ns": mtime_ns, + "info_name": name, + "tags": tags, + "fname": os.path.basename(abs_path), + "metadata": None, + "hash": None, + "mime_type": mime_type, + "job_id": job_id, + } + if tags: + ensure_tags_exist(session, tags) + result = batch_insert_seed_assets(session, [spec], owner_id=owner_id) + session.commit() + return result.won_paths > 0 + + +def _register_existing_asset( + asset_hash: str, + name: str, + user_metadata: UserMetadata = None, + tags: list[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, +) -> RegisterAssetResult: + user_metadata = user_metadata or {} + + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"No asset with hash {asset_hash}") + + if mime_type and not asset.mime_type: + update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type) + + if preview_id: + if not reference_exists(session, preview_id): + preview_id = None + + ref, ref_created = get_or_create_reference( + session, + asset_id=asset.id, + owner_id=owner_id, + name=name, + preview_id=preview_id, + ) + + if not ref_created: + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + + tag_names = get_reference_tags(session, reference_id=ref.id) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=False, + ) + session.commit() + return result + + new_meta = dict(user_metadata) + computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta: + set_reference_metadata( + session, + reference_id=ref.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_reference_tags( + session, + reference_id=ref.id, + tags=tags, + origin=tag_origin, + ) + + tag_names = get_reference_tags(session, reference_id=ref.id) + session.refresh(ref) + result = RegisterAssetResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created=True, + ) + session.commit() + + return result + + + +def _update_metadata_with_filename( + session: Session, + reference_id: str, + file_path: str | None, + current_metadata: dict | None, + user_metadata: dict[str, Any], +) -> None: + computed_filename = compute_relative_filename(file_path) if file_path else None + + current_meta = current_metadata or {} + new_meta = dict(current_meta) + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + set_reference_metadata( + session, + reference_id=reference_id, + user_metadata=new_meta, + ) + + +def _sanitize_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + return n if n else fallback + + +class HashMismatchError(Exception): + pass + + +class DependencyMissingError(Exception): + def __init__(self, message: str): + 
self.message = message + super().__init__(message) + + +def upload_from_temp_path( + temp_path: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + client_filename: str | None = None, + owner_id: str = "", + expected_hash: str | None = None, + mime_type: str | None = None, + preview_id: str | None = None, +) -> UploadResult: + try: + digest, _ = hashing.compute_blake3_hash(temp_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_hash and asset_hash != expected_hash.strip().lower(): + raise HashMismatchError("Uploaded file hash does not match provided hash.") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _sanitize_filename(name or client_filename, fallback=digest) + result = _register_existing_asset( + asset_hash=asset_hash, + name=display_name, + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + mime_type=mime_type, + preview_id=preview_id, + ) + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) + + if not tags: + raise ValueError("tags are required for new asset uploads") + base_dir, subdirs = resolve_destination_from_tags(tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + validate_path_within_base(dest_abs, base_dir) + + content_type = mime_type or ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + ingest_result = _ingest_file_from_path( + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name or client_filename, fallback=digest), + owner_id=owner_id, + preview_id=preview_id, + user_metadata=user_metadata or {}, + tags=tags, + tag_origin="manual", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def register_file_in_place( + abs_path: str, + name: str, + tags: 
list[str], + owner_id: str = "", + mime_type: str | None = None, +) -> UploadResult: + """Register an already-saved file in the asset database without moving it. + + Tags are derived from the filesystem path (root category + subfolder names), + merged with any caller-provided tags, matching the behavior of the scanner. + If the path is not under a known root, only the caller-provided tags are used. + """ + try: + _, path_tags = get_name_and_tags_from_asset_path(abs_path) + except ValueError: + path_tags = [] + merged_tags = normalize_tags([*path_tags, *tags]) + + try: + digest, _ = hashing.compute_blake3_hash(abs_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash file: {e}") + asset_hash = "blake3:" + digest + + size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) + content_type = mime_type or ( + mimetypes.guess_type(abs_path, strict=False)[0] + or "application/octet-stream" + ) + + ingest_result = _ingest_file_from_path( + abs_path=abs_path, + asset_hash=asset_hash, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name, fallback=digest), + owner_id=owner_id, + tags=merged_tags, + tag_origin="upload", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + +def create_from_hash( + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, +) -> UploadResult | None: + canonical = hash_str.strip().lower() + + try: + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + mime_type=mime_type, + preview_id=preview_id, + ) + except ValueError: + logging.warning("create_from_hash: no asset found for hash %s", canonical) + return None + + return UploadResult( + ref=result.ref, + asset=result.asset, + tags=result.tags, + created_new=False, + ) diff --git a/app/assets/services/metadata_extract.py b/app/assets/services/metadata_extract.py new file mode 100644 index 000000000..a004929bc --- /dev/null +++ b/app/assets/services/metadata_extract.py @@ -0,0 +1,327 @@ +"""Metadata extraction for asset scanning. 
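+
+Extraction is deliberately cheap: for safetensors files only the
+8-byte length prefix and the JSON header are read; tensor data is
+never loaded.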
+ +Tier 1: Filesystem metadata (zero parsing) +Tier 2: Safetensors header metadata (fast JSON read only) +""" + +from __future__ import annotations + +import json +import logging +import mimetypes +import os +import struct +from dataclasses import dataclass +from typing import Any + +from utils.mime_types import init_mime_types + +init_mime_types() + +# Supported safetensors extensions +SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"}) + +# Maximum safetensors header size to read (8MB) +MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024 + + +@dataclass +class ExtractedMetadata: + """Metadata extracted from a file during scanning.""" + + # Tier 1: Filesystem (always available) + filename: str = "" + file_path: str = "" # Full absolute path to the file + content_length: int = 0 + content_type: str | None = None + format: str = "" # file extension without dot + + # Tier 2: Safetensors header (if available) + base_model: str | None = None + trained_words: list[str] | None = None + air: str | None = None # CivitAI AIR identifier + has_preview_images: bool = False + + # Source provenance (populated if embedded in safetensors) + source_url: str | None = None + source_arn: str | None = None + repo_url: str | None = None + preview_url: str | None = None + source_hash: str | None = None + + # HuggingFace specific + repo_id: str | None = None + revision: str | None = None + filepath: str | None = None + resolve_url: str | None = None + + def to_user_metadata(self) -> dict[str, Any]: + """Convert to user_metadata dict for AssetReference.user_metadata JSON field.""" + data: dict[str, Any] = { + "filename": self.filename, + "content_length": self.content_length, + "format": self.format, + } + if self.file_path: + data["file_path"] = self.file_path + if self.content_type: + data["content_type"] = self.content_type + + # Tier 2 fields + if self.base_model: + data["base_model"] = self.base_model + if self.trained_words: + data["trained_words"] = self.trained_words + if self.air: + data["air"] = self.air + if self.has_preview_images: + data["has_preview_images"] = True + + # Source provenance + if self.source_url: + data["source_url"] = self.source_url + if self.source_arn: + data["source_arn"] = self.source_arn + if self.repo_url: + data["repo_url"] = self.repo_url + if self.preview_url: + data["preview_url"] = self.preview_url + if self.source_hash: + data["source_hash"] = self.source_hash + + # HuggingFace + if self.repo_id: + data["repo_id"] = self.repo_id + if self.revision: + data["revision"] = self.revision + if self.filepath: + data["filepath"] = self.filepath + if self.resolve_url: + data["resolve_url"] = self.resolve_url + + return data + + def to_meta_rows(self, reference_id: str) -> list[dict]: + """Convert to asset_reference_meta rows for typed/indexed querying.""" + rows: list[dict] = [] + + def add_str(key: str, val: str | None, ordinal: int = 0) -> None: + if val: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": ordinal, + "val_str": val[:2048] if len(val) > 2048 else val, + "val_num": None, + "val_bool": None, + "val_json": None, + }) + + def add_num(key: str, val: int | float | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, + "val_num": val, + "val_bool": None, + "val_json": None, + }) + + def add_bool(key: str, val: bool | None) -> None: + if val is not None: + rows.append({ + "asset_reference_id": reference_id, + "key": key, + "ordinal": 0, + "val_str": None, 
+ "val_num": None, + "val_bool": val, + "val_json": None, + }) + + # Tier 1 + add_str("filename", self.filename) + add_num("content_length", self.content_length) + add_str("content_type", self.content_type) + add_str("format", self.format) + + # Tier 2 + add_str("base_model", self.base_model) + add_str("air", self.air) + has_previews = self.has_preview_images if self.has_preview_images else None + add_bool("has_preview_images", has_previews) + + # trained_words as multiple rows with ordinals + if self.trained_words: + for i, word in enumerate(self.trained_words[:100]): # limit to 100 words + add_str("trained_words", word, ordinal=i) + + # Source provenance + add_str("source_url", self.source_url) + add_str("source_arn", self.source_arn) + add_str("repo_url", self.repo_url) + add_str("preview_url", self.preview_url) + add_str("source_hash", self.source_hash) + + # HuggingFace + add_str("repo_id", self.repo_id) + add_str("revision", self.revision) + add_str("filepath", self.filepath) + add_str("resolve_url", self.resolve_url) + + return rows + + +def _read_safetensors_header( + path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE +) -> dict[str, Any] | None: + """Read only the JSON header from a safetensors file. + + This is very fast - reads 8 bytes for header length, then the JSON header. + No tensor data is loaded. + + Args: + path: Absolute path to safetensors file + max_size: Maximum header size to read (default 8MB) + + Returns: + Parsed header dict or None if failed + """ + try: + with open(path, "rb") as f: + header_bytes = f.read(8) + if len(header_bytes) < 8: + return None + length_of_header = struct.unpack(" max_size: + return None + header_data = f.read(length_of_header) + if len(header_data) < length_of_header: + return None + return json.loads(header_data.decode("utf-8")) + except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error): + return None + + +def _extract_safetensors_metadata( + header: dict[str, Any], meta: ExtractedMetadata +) -> None: + """Extract metadata from safetensors header __metadata__ section. + + Modifies meta in-place. 
+ """ + st_meta = header.get("__metadata__", {}) + if not isinstance(st_meta, dict): + return + + # Common model metadata + meta.base_model = ( + st_meta.get("ss_base_model_version") + or st_meta.get("modelspec.base_model") + or st_meta.get("base_model") + ) + + # Trained words / trigger words + trained_words = st_meta.get("ss_tag_frequency") + if trained_words and isinstance(trained_words, str): + try: + tag_freq = json.loads(trained_words) + # Extract unique tags from all datasets + all_tags: set[str] = set() + for dataset_tags in tag_freq.values(): + if isinstance(dataset_tags, dict): + all_tags.update(dataset_tags.keys()) + if all_tags: + meta.trained_words = sorted(all_tags)[:100] + except json.JSONDecodeError: + pass + + # Direct trained_words field (some formats) + if not meta.trained_words: + tw = st_meta.get("trained_words") + if isinstance(tw, str): + try: + parsed = json.loads(tw) + if isinstance(parsed, list): + meta.trained_words = [str(x) for x in parsed] + else: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + except json.JSONDecodeError: + meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()] + elif isinstance(tw, list): + meta.trained_words = [str(x) for x in tw] + + # CivitAI AIR + meta.air = st_meta.get("air") or st_meta.get("modelspec.air") + + # Preview images (ssmd_cover_images) + cover_images = st_meta.get("ssmd_cover_images") + if cover_images: + meta.has_preview_images = True + + # Source provenance fields + meta.source_url = st_meta.get("source_url") + meta.source_arn = st_meta.get("source_arn") + meta.repo_url = st_meta.get("repo_url") + meta.preview_url = st_meta.get("preview_url") + meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash") + + # HuggingFace fields + meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id") + meta.revision = st_meta.get("revision") or st_meta.get("hf_revision") + meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath") + meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url") + + +def extract_file_metadata( + abs_path: str, + stat_result: os.stat_result | None = None, + relative_filename: str | None = None, +) -> ExtractedMetadata: + """Extract metadata from a file using tier 1 and tier 2 methods. 
+ + Tier 1: Filesystem metadata from path and stat + Tier 2: Safetensors header parsing if applicable + + Args: + abs_path: Absolute path to the file + stat_result: Optional pre-fetched stat result (saves a syscall) + relative_filename: Optional relative filename to use instead of basename + (e.g., "flux/123/model.safetensors" for model paths) + + Returns: + ExtractedMetadata with all available fields populated + """ + meta = ExtractedMetadata() + + # Tier 1: Filesystem metadata + meta.filename = relative_filename or os.path.basename(abs_path) + meta.file_path = abs_path + _, ext = os.path.splitext(abs_path) + meta.format = ext.lstrip(".").lower() if ext else "" + + mime_type, _ = mimetypes.guess_type(abs_path) + meta.content_type = mime_type + + # Size from stat + if stat_result is None: + try: + stat_result = os.stat(abs_path, follow_symlinks=True) + except OSError: + pass + + if stat_result: + meta.content_length = stat_result.st_size + + # Tier 2: Safetensors header (if applicable and enabled) + if ext.lower() in SAFETENSORS_EXTENSIONS: + header = _read_safetensors_header(abs_path) + if header: + try: + _extract_safetensors_metadata(header, meta) + except Exception as e: + logging.debug("Safetensors meta extract failed %s: %s", abs_path, e) + + return meta diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py new file mode 100644 index 000000000..f5dd7f7fd --- /dev/null +++ b/app/assets/services/path_utils.py @@ -0,0 +1,167 @@ +import os +from pathlib import Path +from typing import Literal + +import folder_paths +from app.assets.helpers import normalize_tags + + +_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"}) + + +def get_comfy_models_folders() -> list[tuple[str, list[str]]]: + """Build list of (folder_name, base_paths[]) for all model locations. + + Includes every category registered in folder_names_and_paths, + regardless of whether its paths are under the main models_dir, + but excludes non-model entries like custom_nodes. 
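+ + For illustration, a stock install typically yields pairs such as ("checkpoints", [".../models/checkpoints"]); the exact categories depend on what is registered in folder_paths.folder_names_and_paths.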
+ """ + targets: list[tuple[str, list[str]]] = [] + for name, values in folder_paths.folder_names_and_paths.items(): + if name in _NON_MODEL_FOLDER_NAMES: + continue + paths, _exts = values[0], values[1] + if paths: + targets.append((name, paths)) + return targets + + +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + if not tags: + raise ValueError("tags must not be empty") + root = tags[0].lower() + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + elif root == "input": + base_dir = os.path.abspath(folder_paths.get_input_directory()) + raw_subdirs = tags[1:] + elif root == "output": + base_dir = os.path.abspath(folder_paths.get_output_directory()) + raw_subdirs = tags[1:] + else: + raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'") + _sep_chars = frozenset(("/", "\\", os.sep)) + for i in raw_subdirs: + if i in (".", "..") or _sep_chars & set(i): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + + +def validate_path_within_base(candidate: str, base: str) -> None: + cand_abs = Path(os.path.abspath(candidate)) + base_abs = Path(os.path.abspath(base)) + if not cand_abs.is_relative_to(base_abs): + raise ValueError("destination escapes base directory") + + +def compute_relative_filename(file_path: str) -> str | None: + """ + Return the model's path relative to the last well-known folder (the model category), + using forward slashes, eg: + /.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors" + /.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors" + + For non-model paths, returns None. + """ + try: + root_category, rel_path = get_asset_category_and_relative_path(file_path) + except ValueError: + return None + + p = Path(rel_path) + parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)] + if not parts: + return None + + if root_category == "models": + # parts[0] is the category ("checkpoints", "vae", etc) – drop it + inside = parts[1:] if len(parts) > 1 else [parts[0]] + return "/".join(inside) + return "/".join(parts) # input/output: keep all parts + + +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: + """Determine which root category a file path belongs to. + + Categories: + - 'input': under folder_paths.get_input_directory() + - 'output': under folder_paths.get_output_directory() + - 'models': under any base path from get_comfy_models_folders() + + Returns: + (root_category, relative_path_inside_that_root) + + Raises: + ValueError: path does not belong to any known root. + """ + fp_abs = os.path.abspath(file_path) + + def _check_is_within(child: str, parent: str) -> bool: + return Path(child).is_relative_to(parent) + + def _compute_relative(child: str, parent: str) -> str: + # Normalize relative path, stripping any leading ".." components + # by anchoring to root (os.sep) then computing relpath back from it. 
+ return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) + + # 1) input + input_base = os.path.abspath(folder_paths.get_input_directory()) + if _check_is_within(fp_abs, input_base): + return "input", _compute_relative(fp_abs, input_base) + + # 2) output + output_base = os.path.abspath(folder_paths.get_output_directory()) + if _check_is_within(fp_abs, output_base): + return "output", _compute_relative(fp_abs, output_base) + + # 3) models (check deepest matching base to avoid ambiguity) + best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket) + for bucket, bases in get_comfy_models_folders(): + for b in bases: + base_abs = os.path.abspath(b) + if not _check_is_within(fp_abs, base_abs): + continue + cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs)) + if best is None or cand[0] > best[0]: + best = cand + + if best is not None: + _, bucket, rel_inside = best + combined = os.path.join(bucket, rel_inside) + return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) + + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: + """Return (name, tags) derived from a filesystem path. + + - name: base filename with extension + - tags: [root_category] + parent folder names in order + + Raises: + ValueError: path does not belong to any known root. + """ + root_category, some_path = get_asset_category_and_relative_path(file_path) + p = Path(some_path) + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] + return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts]))) diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py new file mode 100644 index 000000000..0eb128f58 --- /dev/null +++ b/app/assets/services/schemas.py @@ -0,0 +1,113 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetReference + +UserMetadata = dict[str, Any] | None + + +@dataclass(frozen=True) +class AssetData: + hash: str | None + size_bytes: int | None + mime_type: str | None + + +@dataclass(frozen=True) +class ReferenceData: + """Data transfer object for AssetReference.""" + + id: str + name: str + file_path: str | None + user_metadata: UserMetadata + preview_id: str | None + created_at: datetime + updated_at: datetime + system_metadata: dict[str, Any] | None = None + job_id: str | None = None + last_access_time: datetime | None = None + + +@dataclass(frozen=True) +class AssetDetailResult: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class RegisterAssetResult: + ref: ReferenceData + asset: AssetData + tags: list[str] + created: bool + + +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + ref_created: bool + ref_updated: bool + reference_id: str | None + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +@dataclass(frozen=True) +class AssetSummaryData: + ref: ReferenceData + asset: AssetData | None + tags: list[str] + + +@dataclass(frozen=True) +class ListAssetsResult: + items: list[AssetSummaryData] + total: int + + +@dataclass(frozen=True) +class DownloadResolutionResult: + abs_path: str + content_type: str + download_name: str + + +@dataclass(frozen=True) +class UploadResult: + ref: ReferenceData + asset: 
AssetData + tags: list[str] + created_new: bool + + +def extract_reference_data(ref: AssetReference) -> ReferenceData: + return ReferenceData( + id=ref.id, + name=ref.name, + file_path=ref.file_path, + user_metadata=ref.user_metadata, + preview_id=ref.preview_id, + system_metadata=ref.system_metadata, + job_id=ref.job_id, + created_at=ref.created_at, + updated_at=ref.updated_at, + last_access_time=ref.last_access_time, + ) + + +def extract_asset_data(asset: Asset | None) -> AssetData | None: + if asset is None: + return None + return AssetData( + hash=asset.hash, + size_bytes=asset.size_bytes, + mime_type=asset.mime_type, + ) diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py new file mode 100644 index 000000000..37b612753 --- /dev/null +++ b/app/assets/services/tagging.py @@ -0,0 +1,98 @@ +from typing import Sequence + +from app.assets.database.queries import ( + AddTagsResult, + RemoveTagsResult, + add_tags_to_reference, + get_reference_with_owner_check, + list_tags_with_usage, + remove_tags_from_reference, +) +from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets +from app.assets.services.schemas import TagUsage +from app.database.db import create_session + + +def apply_tags( + reference_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> AddTagsResult: + with create_session() as session: + ref_row = get_reference_with_owner_check(session, reference_id, owner_id) + + result = add_tags_to_reference( + session, + reference_id=reference_id, + tags=tags, + origin=origin, + create_if_missing=True, + reference_row=ref_row, + ) + session.commit() + + return result + + +def remove_tags( + reference_id: str, + tags: list[str], + owner_id: str = "", +) -> RemoveTagsResult: + with create_session() as session: + get_reference_with_owner_check(session, reference_id, owner_id) + + result = remove_tags_from_reference( + session, + reference_id=reference_id, + tags=tags, + ) + session.commit() + + return result + + +def list_tags( + prefix: str | None = None, + limit: int = 100, + offset: int = 0, + order: str = "count_desc", + include_zero: bool = True, + owner_id: str = "", +) -> tuple[list[TagUsage], int]: + limit = max(1, min(1000, limit)) + offset = max(0, offset) + + with create_session() as session: + rows, total = list_tags_with_usage( + session, + prefix=prefix, + limit=limit, + offset=offset, + include_zero=include_zero, + order=order, + owner_id=owner_id, + ) + + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total + + +def list_tag_histogram( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + with create_session() as session: + return list_tag_counts_for_filtered_assets( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + ) diff --git a/app/database/db.py b/app/database/db.py index 1de8b80ed..0aab09a49 100644 --- a/app/database/db.py +++ b/app/database/db.py @@ -3,6 +3,7 @@ import os import shutil from app.logger import log_startup_warning from utils.install_util import get_missing_requirements_message +from filelock import FileLock, Timeout from comfy.cli_args import args _DB_AVAILABLE = False @@ -14,8 +15,12 @@ try: from alembic.config import Config from alembic.runtime.migration 
import MigrationContext from alembic.script import ScriptDirectory - from sqlalchemy import create_engine + from sqlalchemy import create_engine, event from sqlalchemy.orm import sessionmaker + from sqlalchemy.pool import StaticPool + + from app.database.models import Base + import app.assets.database.models # noqa: F401 — register models with Base.metadata _DB_AVAILABLE = True except ImportError as e: @@ -65,9 +70,69 @@ def get_db_path(): raise ValueError(f"Unsupported database URL '{url}'.") +_db_lock = None + +def _acquire_file_lock(db_path): + """Acquire an OS-level file lock to prevent multi-process access. + + Uses filelock for cross-platform support (macOS, Linux, Windows). + The OS automatically releases the lock when the process exits, even on crashes. + """ + global _db_lock + lock_path = db_path + ".lock" + _db_lock = FileLock(lock_path) + try: + _db_lock.acquire(timeout=0) + except Timeout: + raise RuntimeError( + f"Could not acquire lock on database '{db_path}'. " + "Another ComfyUI process may already be using it. " + "Use --database-url to specify a separate database file." + ) + + +def _is_memory_db(db_url): + """Check if the database URL refers to an in-memory SQLite database.""" + return db_url in ("sqlite:///:memory:", "sqlite://") + + def init_db(): db_url = args.database_url logging.debug(f"Database URL: {db_url}") + + if _is_memory_db(db_url): + _init_memory_db(db_url) + else: + _init_file_db(db_url) + + +def _init_memory_db(db_url): + """Initialize an in-memory SQLite database using metadata.create_all. + + Alembic migrations don't work with in-memory SQLite because each + connection gets its own separate database — tables created by Alembic's + internal connection are lost immediately. + """ + engine = create_engine( + db_url, + poolclass=StaticPool, + connect_args={"check_same_thread": False}, + ) + + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + Base.metadata.create_all(engine) + + global Session + Session = sessionmaker(bind=engine) + + +def _init_file_db(db_url): + """Initialize a file-backed SQLite database using Alembic migrations.""" db_path = get_db_path() db_exists = os.path.exists(db_path) @@ -75,6 +140,14 @@ def init_db(): # Check if we need to upgrade engine = create_engine(db_url) + + # Enable foreign key enforcement for SQLite + @event.listens_for(engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + conn = engine.connect() context = MigrationContext.configure(conn) @@ -104,6 +177,12 @@ def init_db(): logging.exception("Error upgrading database: ") raise e + # Acquire an OS-level file lock after migrations are complete. + # Alembic uses its own connection, so we must wait until it's done + # before locking — otherwise our own lock blocks the migration. 
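+ # The lock is advisory: _acquire_file_lock creates db_path + ".lock" beside the database, and a second process calling acquire(timeout=0) fails fast with filelock's Timeout instead of blocking.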
+ conn.close() + _acquire_file_lock(db_path) + global Session Session = sessionmaker(bind=engine) diff --git a/app/database/models.py b/app/database/models.py index 6facfb8f2..b02856f6e 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,14 +1,30 @@ -from sqlalchemy.orm import declarative_base +from typing import Any +from datetime import datetime +from sqlalchemy import MetaData +from sqlalchemy.orm import DeclarativeBase -Base = declarative_base() +NAMING_CONVENTION = { + "ix": "ix_%(table_name)s_%(column_0_N_name)s", + "uq": "uq_%(table_name)s_%(column_0_N_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +} +class Base(DeclarativeBase): + metadata = MetaData(naming_convention=NAMING_CONVENTION) -def to_dict(obj): +def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() - return { - field: (val.to_dict() if hasattr(val, "to_dict") else val) - for field in fields - if (val := getattr(obj, field)) - } + out: dict[str, Any] = {} + for field in fields: + val = getattr(obj, field) + if val is None and not include_none: + continue + if isinstance(val, datetime): + out[field] = val.isoformat() + else: + out[field] = val + return out # TODO: Define models here diff --git a/app/frontend_management.py b/app/frontend_management.py index bdaa85812..f753ef0de 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -17,7 +17,7 @@ from importlib.metadata import version import requests from typing_extensions import NotRequired -from utils.install_util import get_missing_requirements_message, requirements_path +from utils.install_util import get_missing_requirements_message, get_required_packages_versions from comfy.cli_args import DEFAULT_VERSION_STRING import app.logger @@ -45,25 +45,7 @@ def get_installed_frontend_version(): def get_required_frontend_version(): - """Get the required frontend version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-frontend-package=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-frontend-package not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. Cannot determine required frontend version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-frontend-package", None) def check_frontend_version(): @@ -217,25 +199,7 @@ class FrontendManager: @classmethod def get_required_templates_version(cls) -> str: - """Get the required workflow templates version from requirements.txt.""" - try: - with open(requirements_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if line.startswith("comfyui-workflow-templates=="): - version_str = line.split("==")[-1] - if not is_valid_version(version_str): - logging.error(f"Invalid templates version format in requirements.txt: {version_str}") - return None - return version_str - logging.error("comfyui-workflow-templates not found in requirements.txt") - return None - except FileNotFoundError: - logging.error("requirements.txt not found. 
Cannot determine required templates version.") - return None - except Exception as e: - logging.error(f"Error reading requirements.txt: {e}") - return None + return get_required_packages_versions().get("comfyui-workflow-templates", None) @classmethod def default_frontend_path(cls) -> str: diff --git a/app/model_manager.py b/app/model_manager.py index ab36bca74..f124d1117 100644 --- a/app/model_manager.py +++ b/app/model_manager.py @@ -44,7 +44,7 @@ class ModelFileManager: @routes.get("/experiment/models/{folder}") async def get_all_models(request): folder = request.match_info.get("folder", None) - if not folder in folder_paths.folder_names_and_paths: + if folder not in folder_paths.folder_names_and_paths: return web.Response(status=404) files = self.get_model_file_list(folder) return web.json_response(files) @@ -55,7 +55,7 @@ class ModelFileManager: path_index = int(request.match_info.get("path_index", None)) filename = request.match_info.get("filename", None) - if not folder_name in folder_paths.folder_names_and_paths: + if folder_name not in folder_paths.folder_names_and_paths: return web.Response(status=404) folders = folder_paths.folder_names_and_paths[folder_name] diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py new file mode 100644 index 000000000..d9aab5b22 --- /dev/null +++ b/app/node_replace_manager.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from aiohttp import web + +from typing import TYPE_CHECKING, TypedDict +if TYPE_CHECKING: + from comfy_api.latest._io_public import NodeReplace + +from comfy_execution.graph_utils import is_link +import nodes + +class NodeStruct(TypedDict): + inputs: dict[str, str | int | float | bool | tuple[str, int]] + class_type: str + _meta: dict[str, str] + +def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct: + new_node_struct = node_struct.copy() + if empty_inputs: + new_node_struct["inputs"] = {} + else: + new_node_struct["inputs"] = node_struct["inputs"].copy() + new_node_struct["_meta"] = node_struct["_meta"].copy() + return new_node_struct + + +class NodeReplaceManager: + """Manages node replacement registrations.""" + + def __init__(self): + self._replacements: dict[str, list[NodeReplace]] = {} + + def register(self, node_replace: NodeReplace): + """Register a node replacement mapping.""" + self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace) + + def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: + """Get replacements for an old node ID.""" + return self._replacements.get(old_node_id) + + def has_replacement(self, old_node_id: str) -> bool: + """Check if a replacement exists for an old node ID.""" + return old_node_id in self._replacements + + def apply_replacements(self, prompt: dict[str, NodeStruct]): + connections: dict[str, list[tuple[str, str, int]]] = {} + need_replacement: set[str] = set() + for node_number, node_struct in prompt.items(): + if "class_type" not in node_struct or "inputs" not in node_struct: + continue + class_type = node_struct["class_type"] + # need replacement if not in NODE_CLASS_MAPPINGS and has replacement + if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type): + need_replacement.add(node_number) + # keep track of connections + for input_id, input_value in node_struct["inputs"].items(): + if is_link(input_value): + conn_number = input_value[0] + connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1])) + for node_number in 
need_replacement: + node_struct = prompt[node_number] + class_type = node_struct["class_type"] + replacements = self.get_replacement(class_type) + if replacements is None: + continue + # just use the first replacement + replacement = replacements[0] + new_node_id = replacement.new_node_id + # if replacement is not a valid node, skip trying to replace it, as that would only cause confusion + if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys(): + continue + # first, replace node id (class_type) + new_node_struct = copy_node_struct(node_struct, empty_inputs=True) + new_node_struct["class_type"] = new_node_id + # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema + # second, replace inputs + if replacement.input_mapping is not None: + for input_map in replacement.input_mapping: + if "set_value" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"] + elif "old_id" in input_map: + new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]] + # finalize input replacement + prompt[node_number] = new_node_struct + # third, replace outputs + if replacement.output_mapping is not None: + # re-mapping outputs requires changing the input values of nodes that receive connections from this one + if node_number in connections: + for conns in connections[node_number]: + conn_node_number, conn_input_id, old_output_idx = conns + for output_map in replacement.output_mapping: + if output_map["old_idx"] == old_output_idx: + new_output_idx = output_map["new_idx"] + previous_input = prompt[conn_node_number]["inputs"][conn_input_id] + previous_input[1] = new_output_idx + + def as_dict(self): + """Serialize all replacements to dict.""" + return { + k: [v.as_dict() for v in v_list] + for k, v_list in self._replacements.items() + } + + def add_routes(self, routes): + @routes.get("/node_replacements") + async def get_node_replacements(request): + return web.json_response(self.as_dict()) diff --git a/app/subgraph_manager.py b/app/subgraph_manager.py index dbe404541..08ad8c302 100644 --- a/app/subgraph_manager.py +++ b/app/subgraph_manager.py @@ -10,6 +10,7 @@ import hashlib class Source: custom_node = "custom_node" + templates = "templates" class SubgraphEntry(TypedDict): source: str @@ -38,9 +39,21 @@ class CustomNodeSubgraphEntryInfo(TypedDict): class SubgraphManager: def __init__(self): self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + self.cached_blueprint_subgraphs: dict[SubgraphEntry] | None = None + + def _create_entry(self, file: str, source: str, node_pack: str) -> tuple[str, SubgraphEntry]: + """Create a subgraph entry from a file path. 
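The entry id is a sha256 over source + path, so it is reproducible across backend reloads while staying unique per source.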
Expects normalized path (forward slashes).""" + entry_id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": source, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": {"node_pack": node_pack}, + } + return entry_id, entry async def load_entry_data(self, entry: SubgraphEntry): - with open(entry['path'], 'r') as f: + with open(entry['path'], 'r', encoding='utf-8') as f: entry['data'] = f.read() return entry @@ -60,53 +73,60 @@ class SubgraphManager: return entries async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): - # if not forced to reload and cached, return cache + """Load subgraphs from custom nodes.""" if not force_reload and self.cached_custom_node_subgraphs is not None: return self.cached_custom_node_subgraphs - # Load subgraphs from custom nodes - subfolder = "subgraphs" - subgraphs_dict: dict[SubgraphEntry] = {} + subgraphs_dict: dict[SubgraphEntry] = {} for folder in folder_paths.get_folder_paths("custom_nodes"): - pattern = os.path.join(folder, f"*/{subfolder}/*.json") - matched_files = glob.glob(pattern) - for file in matched_files: - # replace backslashes with forward slashes + pattern = os.path.join(folder, "*/subgraphs/*.json") + for file in glob.glob(pattern): file = file.replace('\\', '/') - info: CustomNodeSubgraphEntryInfo = { - "node_pack": "custom_nodes." + file.split('/')[-3] - } - source = Source.custom_node - # hash source + path to make sure id will be as unique as possible, but - # reproducible across backend reloads - id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() - entry: SubgraphEntry = { - "source": Source.custom_node, - "name": os.path.splitext(os.path.basename(file))[0], - "path": file, - "info": info, - } - subgraphs_dict[id] = entry + node_pack = "custom_nodes." 
+ file.split('/')[-3] + entry_id, entry = self._create_entry(file, Source.custom_node, node_pack) + subgraphs_dict[entry_id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict return subgraphs_dict - async def get_custom_node_subgraph(self, id: str, loadedModules): - subgraphs = await self.get_custom_node_subgraphs(loadedModules) - entry: SubgraphEntry = subgraphs.get(id, None) - if entry is not None and entry.get('data', None) is None: + async def get_blueprint_subgraphs(self, force_reload=False): + """Load subgraphs from the blueprints directory.""" + if not force_reload and self.cached_blueprint_subgraphs is not None: + return self.cached_blueprint_subgraphs + + subgraphs_dict: dict[SubgraphEntry] = {} + blueprints_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'blueprints') + + if os.path.exists(blueprints_dir): + for file in glob.glob(os.path.join(blueprints_dir, "*.json")): + file = file.replace('\\', '/') + entry_id, entry = self._create_entry(file, Source.templates, "comfyui") + subgraphs_dict[entry_id] = entry + + self.cached_blueprint_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_all_subgraphs(self, loadedModules, force_reload=False): + """Get all subgraphs from all sources (custom nodes and blueprints).""" + custom_node_subgraphs = await self.get_custom_node_subgraphs(loadedModules, force_reload) + blueprint_subgraphs = await self.get_blueprint_subgraphs(force_reload) + return {**custom_node_subgraphs, **blueprint_subgraphs} + + async def get_subgraph(self, id: str, loadedModules): + """Get a specific subgraph by ID from any source.""" + entry = (await self.get_all_subgraphs(loadedModules)).get(id) + if entry is not None and entry.get('data') is None: await self.load_entry_data(entry) return entry def add_routes(self, routes, loadedModules): @routes.get("/global_subgraphs") async def get_global_subgraphs(request): - subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) - # NOTE: we may want to include other sources of global subgraphs such as templates in the future; - # that's the reasoning for the current implementation + subgraphs_dict = await self.get_all_subgraphs(loadedModules) return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) @routes.get("/global_subgraphs/{id}") async def get_global_subgraph(request): id = request.match_info.get("id", None) - subgraph = await self.get_custom_node_subgraph(id, loadedModules) + subgraph = await self.get_subgraph(id, loadedModules) return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/app/user_manager.py b/app/user_manager.py index e2c00dab2..e18afb71b 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -6,6 +6,7 @@ import uuid import glob import shutil import logging +import tempfile from aiohttp import web from urllib import parse from comfy.cli_args import args @@ -377,8 +378,15 @@ class UserManager(): try: body = await request.read() - with open(path, "wb") as f: - f.write(body) + dir_name = os.path.dirname(path) + fd, tmp_path = tempfile.mkstemp(dir=dir_name) + try: + with os.fdopen(fd, "wb") as f: + f.write(body) + os.replace(tmp_path, path) + except: + os.unlink(tmp_path) + raise except OSError as e: logging.warning(f"Error saving file '{path}': {e}") return web.Response( diff --git a/blueprints/.glsl/Brightness_and_Contrast_1.frag b/blueprints/.glsl/Brightness_and_Contrast_1.frag new file mode 100644 index 000000000..da5424080 --- /dev/null +++ b/blueprints/.glsl/Brightness_and_Contrast_1.frag @@ -0,0 
+1,44 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform float u_float0; // Brightness slider -100..100 +uniform float u_float1; // Contrast slider -100..100 + +in vec2 v_texCoord; +out vec4 fragColor; + +const float MID_GRAY = 0.18; // 18% reflectance + +// sRGB gamma 2.2 approximation +vec3 srgbToLinear(vec3 c) { + return pow(max(c, 0.0), vec3(2.2)); +} + +vec3 linearToSrgb(vec3 c) { + return pow(max(c, 0.0), vec3(1.0/2.2)); +} + +float mapBrightness(float b) { + return clamp(b / 100.0, -1.0, 1.0); +} + +float mapContrast(float c) { + return clamp(c / 100.0 + 1.0, 0.0, 2.0); +} + +void main() { + vec4 orig = texture(u_image0, v_texCoord); + + float brightness = mapBrightness(u_float0); + float contrast = mapContrast(u_float1); + + vec3 lin = srgbToLinear(orig.rgb); + + lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY; + + // Convert back to sRGB + vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0)); + + fragColor = vec4(result, orig.a); +} diff --git a/blueprints/.glsl/Chromatic_Aberration_16.frag b/blueprints/.glsl/Chromatic_Aberration_16.frag new file mode 100644 index 000000000..09a271146 --- /dev/null +++ b/blueprints/.glsl/Chromatic_Aberration_16.frag @@ -0,0 +1,72 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform int u_int0; // Mode +uniform float u_float0; // Amount (0 to 100) + +in vec2 v_texCoord; +out vec4 fragColor; + +const int MODE_LINEAR = 0; +const int MODE_RADIAL = 1; +const int MODE_BARREL = 2; +const int MODE_SWIRL = 3; +const int MODE_DIAGONAL = 4; + +const float AMOUNT_SCALE = 0.0005; +const float RADIAL_MULT = 4.0; +const float BARREL_MULT = 8.0; +const float INV_SQRT2 = 0.70710678118; + +void main() { + vec2 uv = v_texCoord; + vec4 original = texture(u_image0, uv); + + float amount = u_float0 * AMOUNT_SCALE; + + if (amount < 0.000001) { + fragColor = original; + return; + } + + // Aspect-corrected coordinates for circular effects + float aspect = u_resolution.x / u_resolution.y; + vec2 centered = uv - 0.5; + vec2 corrected = vec2(centered.x * aspect, centered.y); + float r = length(corrected); + vec2 dir = r > 0.0001 ? 
corrected / r : vec2(0.0); + vec2 offset = vec2(0.0); + + if (u_int0 == MODE_LINEAR) { + // Horizontal shift (no aspect correction needed) + offset = vec2(amount, 0.0); + } + else if (u_int0 == MODE_RADIAL) { + // Outward from center, stronger at edges + offset = dir * r * amount * RADIAL_MULT; + offset.x /= aspect; // Convert back to UV space + } + else if (u_int0 == MODE_BARREL) { + // Lens distortion simulation (r² falloff) + offset = dir * r * r * amount * BARREL_MULT; + offset.x /= aspect; // Convert back to UV space + } + else if (u_int0 == MODE_SWIRL) { + // Perpendicular to radial (rotational aberration) + vec2 perp = vec2(-dir.y, dir.x); + offset = perp * r * amount * RADIAL_MULT; + offset.x /= aspect; // Convert back to UV space + } + else if (u_int0 == MODE_DIAGONAL) { + // 45° offset (no aspect correction needed) + offset = vec2(amount, amount) * INV_SQRT2; + } + + float red = texture(u_image0, uv + offset).r; + float green = original.g; + float blue = texture(u_image0, uv - offset).b; + + fragColor = vec4(red, green, blue, original.a); +} \ No newline at end of file diff --git a/blueprints/.glsl/Color_Adjustment_15.frag b/blueprints/.glsl/Color_Adjustment_15.frag new file mode 100644 index 000000000..697525f14 --- /dev/null +++ b/blueprints/.glsl/Color_Adjustment_15.frag @@ -0,0 +1,78 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform float u_float0; // temperature (-100 to 100) +uniform float u_float1; // tint (-100 to 100) +uniform float u_float2; // vibrance (-100 to 100) +uniform float u_float3; // saturation (-100 to 100) + +in vec2 v_texCoord; +out vec4 fragColor; + +const float INPUT_SCALE = 0.01; +const float TEMP_TINT_PRIMARY = 0.3; +const float TEMP_TINT_SECONDARY = 0.15; +const float VIBRANCE_BOOST = 2.0; +const float SATURATION_BOOST = 2.0; +const float SKIN_PROTECTION = 0.5; +const float EPSILON = 0.001; +const vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114); + +void main() { + vec4 tex = texture(u_image0, v_texCoord); + vec3 color = tex.rgb; + + // Scale inputs: -100/100 → -1/1 + float temperature = u_float0 * INPUT_SCALE; + float tint = u_float1 * INPUT_SCALE; + float vibrance = u_float2 * INPUT_SCALE; + float saturation = u_float3 * INPUT_SCALE; + + // Temperature (warm/cool): positive = warm, negative = cool + color.r += temperature * TEMP_TINT_PRIMARY; + color.b -= temperature * TEMP_TINT_PRIMARY; + + // Tint (green/magenta): positive = green, negative = magenta + color.g += tint * TEMP_TINT_PRIMARY; + color.r -= tint * TEMP_TINT_SECONDARY; + color.b -= tint * TEMP_TINT_SECONDARY; + + // Single clamp after temperature/tint + color = clamp(color, 0.0, 1.0); + + // Vibrance with skin protection + if (vibrance != 0.0) { + float maxC = max(color.r, max(color.g, color.b)); + float minC = min(color.r, min(color.g, color.b)); + float sat = maxC - minC; + float gray = dot(color, LUMA_WEIGHTS); + + if (vibrance < 0.0) { + // Desaturate: -100 → gray + color = mix(vec3(gray), color, 1.0 + vibrance); + } else { + // Boost less saturated colors more + float vibranceAmt = vibrance * (1.0 - sat); + + // Branchless skin tone protection + float isWarmTone = step(color.b, color.g) * step(color.g, color.r); + float warmth = (color.r - color.b) / max(maxC, EPSILON); + float skinTone = isWarmTone * warmth * sat * (1.0 - sat); + vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION); + + color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST); + } + } + + // Saturation + if (saturation != 0.0) { + float gray = dot(color, LUMA_WEIGHTS); + float 
satMix = saturation < 0.0 + ? 1.0 + saturation // -100 → gray + : 1.0 + saturation * SATURATION_BOOST; // +100 → 3x boost + color = mix(vec3(gray), color, satMix); + } + + fragColor = vec4(clamp(color, 0.0, 1.0), tex.a); +} \ No newline at end of file diff --git a/blueprints/.glsl/Edge-Preserving_Blur_128.frag b/blueprints/.glsl/Edge-Preserving_Blur_128.frag new file mode 100644 index 000000000..f269aebd6 --- /dev/null +++ b/blueprints/.glsl/Edge-Preserving_Blur_128.frag @@ -0,0 +1,94 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform float u_float0; // Blur radius (0–20, default ~5) +uniform float u_float1; // Edge threshold (0–100, default ~30) +uniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels) + +in vec2 v_texCoord; +out vec4 fragColor; + +const int MAX_RADIUS = 20; +const float EPSILON = 0.0001; + +// Perceptual luminance +float getLuminance(vec3 rgb) { + return dot(rgb, vec3(0.299, 0.587, 0.114)); +} + +vec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius, + float sigmaSpatial, float sigmaColor) +{ + vec4 center = texture(u_image0, uv); + vec3 centerRGB = center.rgb; + + float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial); + float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON); + + vec3 sumRGB = vec3(0.0); + float sumWeight = 0.0; + + int step = max(u_int0, 1); + float radius2 = float(radius * radius); + + for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) { + if (dy < -radius || dy > radius) continue; + if (abs(dy) % step != 0) continue; + + for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) { + if (dx < -radius || dx > radius) continue; + if (abs(dx) % step != 0) continue; + + vec2 offset = vec2(float(dx), float(dy)); + float dist2 = dot(offset, offset); + if (dist2 > radius2) continue; + + vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb; + + // Spatial Gaussian + float spatialWeight = exp(dist2 * invSpatial2); + + // Perceptual color distance (weighted RGB) + vec3 diff = sampleRGB - centerRGB; + float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114)); + float colorWeight = exp(colorDist * invColor2); + + float w = spatialWeight * colorWeight; + sumRGB += sampleRGB * w; + sumWeight += w; + } + } + + vec3 resultRGB = sumRGB / max(sumWeight, EPSILON); + return vec4(resultRGB, center.a); // preserve center alpha +} + +void main() { + vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0)); + + float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS)); + int radius = int(radiusF + 0.5); + + if (radius == 0) { + fragColor = texture(u_image0, v_texCoord); + return; + } + + // Edge threshold → color sigma + // Squared curve for better low-end control + float t = clamp(u_float1, 0.0, 100.0) / 100.0; + t *= t; + float sigmaColor = mix(0.01, 0.5, t); + + // Spatial sigma tied to radius + float sigmaSpatial = max(radiusF * 0.75, 0.5); + + fragColor = bilateralFilter( + v_texCoord, + texelSize, + radius, + sigmaSpatial, + sigmaColor + ); +} \ No newline at end of file diff --git a/blueprints/.glsl/Film_Grain_15.frag b/blueprints/.glsl/Film_Grain_15.frag new file mode 100644 index 000000000..21585825b --- /dev/null +++ b/blueprints/.glsl/Film_Grain_15.frag @@ -0,0 +1,124 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8 +uniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain +uniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain +uniform float 
u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only +uniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +// High-quality integer hash (pcg-like) +uint pcg(uint v) { + uint state = v * 747796405u + 2891336453u; + uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u; + return (word >> 22u) ^ word; +} + +// 2D -> 1D hash input +uint hash2d(uvec2 p) { + return pcg(p.x + pcg(p.y)); +} + +// Hash to float [0, 1] +float hashf(uvec2 p) { + return float(hash2d(p)) / float(0xffffffffu); +} + +// Hash to float with offset (for RGB channels) +float hashf(uvec2 p, uint offset) { + return float(pcg(hash2d(p) + offset)) / float(0xffffffffu); +} + +// Convert uniform [0,1] to roughly Gaussian distribution +// Using simple approximation: average of multiple samples +float toGaussian(uvec2 p) { + float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u); + return (sum - 2.0) * 0.7; // Centered, scaled +} + +float toGaussian(uvec2 p, uint offset) { + float sum = hashf(p, offset) + hashf(p, offset + 1u) + + hashf(p, offset + 2u) + hashf(p, offset + 3u); + return (sum - 2.0) * 0.7; +} + +// Smooth noise with better interpolation +float smoothNoise(vec2 p) { + vec2 i = floor(p); + vec2 f = fract(p); + + // Quintic interpolation (less banding than cubic) + f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0); + + uvec2 ui = uvec2(i); + float a = toGaussian(ui); + float b = toGaussian(ui + uvec2(1u, 0u)); + float c = toGaussian(ui + uvec2(0u, 1u)); + float d = toGaussian(ui + uvec2(1u, 1u)); + + return mix(mix(a, b, f.x), mix(c, d, f.x), f.y); +} + +float smoothNoise(vec2 p, uint offset) { + vec2 i = floor(p); + vec2 f = fract(p); + + f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0); + + uvec2 ui = uvec2(i); + float a = toGaussian(ui, offset); + float b = toGaussian(ui + uvec2(1u, 0u), offset); + float c = toGaussian(ui + uvec2(0u, 1u), offset); + float d = toGaussian(ui + uvec2(1u, 1u), offset); + + return mix(mix(a, b, f.x), mix(c, d, f.x), f.y); +} + +void main() { + vec4 color = texture(u_image0, v_texCoord); + + // Luminance (Rec.709) + float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722)); + + // Grain UV (resolution-independent) + vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01); + uvec2 grainPixel = uvec2(grainUV); + + float g; + vec3 grainRGB; + + if (u_int0 == 1) { + // Grainy mode: pure hash noise (no interpolation = no banding) + g = toGaussian(grainPixel); + grainRGB = vec3( + toGaussian(grainPixel, 100u), + toGaussian(grainPixel, 200u), + toGaussian(grainPixel, 300u) + ); + } else { + // Smooth mode: interpolated with quintic curve + g = smoothNoise(grainUV); + grainRGB = vec3( + smoothNoise(grainUV, 100u), + smoothNoise(grainUV, 200u), + smoothNoise(grainUV, 300u) + ); + } + + // Luminance weighting (less grain in highlights) + float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0)); + + // Strength + float strength = u_float0 * 0.15; + + // Color vs monochrome grain + vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0)); + + color.rgb += grainColor * strength * lumWeight; + fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a); +} diff --git a/blueprints/.glsl/Glow_30.frag b/blueprints/.glsl/Glow_30.frag new file mode 100644 index 000000000..0ee152628 --- /dev/null +++ b/blueprints/.glsl/Glow_30.frag @@ -0,0 +1,133 @@ +#version 300 es +precision mediump float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform int u_int0; 
// Blend mode +uniform int u_int1; // Color tint +uniform float u_float0; // Intensity +uniform float u_float1; // Radius +uniform float u_float2; // Threshold + +in vec2 v_texCoord; +out vec4 fragColor; + +const int BLEND_ADD = 0; +const int BLEND_SCREEN = 1; +const int BLEND_SOFT = 2; +const int BLEND_OVERLAY = 3; +const int BLEND_LIGHTEN = 4; + +const float GOLDEN_ANGLE = 2.39996323; +const int MAX_SAMPLES = 48; +const vec3 LUMA = vec3(0.299, 0.587, 0.114); + +float hash(vec2 p) { + p = fract(p * vec2(123.34, 456.21)); + p += dot(p, p + 45.32); + return fract(p.x * p.y); +} + +vec3 hexToRgb(int h) { + return vec3( + float((h >> 16) & 255), + float((h >> 8) & 255), + float(h & 255) + ) * (1.0 / 255.0); +} + +vec3 blend(vec3 base, vec3 glow, int mode) { + if (mode == BLEND_SCREEN) { + return 1.0 - (1.0 - base) * (1.0 - glow); + } + if (mode == BLEND_SOFT) { + return mix( + base - (1.0 - 2.0 * glow) * base * (1.0 - base), + base + (2.0 * glow - 1.0) * (sqrt(base) - base), + step(0.5, glow) + ); + } + if (mode == BLEND_OVERLAY) { + return mix( + 2.0 * base * glow, + 1.0 - 2.0 * (1.0 - base) * (1.0 - glow), + step(0.5, base) + ); + } + if (mode == BLEND_LIGHTEN) { + return max(base, glow); + } + return base + glow; +} + +void main() { + vec4 original = texture(u_image0, v_texCoord); + + float intensity = u_float0 * 0.05; + float radius = u_float1 * u_float1 * 0.012; + + if (intensity < 0.001 || radius < 0.1) { + fragColor = original; + return; + } + + float threshold = 1.0 - u_float2 * 0.01; + float t0 = threshold - 0.15; + float t1 = threshold + 0.15; + + vec2 texelSize = 1.0 / u_resolution; + float radius2 = radius * radius; + + float sampleScale = clamp(radius * 0.75, 0.35, 1.0); + int samples = int(float(MAX_SAMPLES) * sampleScale); + + float noise = hash(gl_FragCoord.xy); + float angleOffset = noise * GOLDEN_ANGLE; + float radiusJitter = 0.85 + noise * 0.3; + + float ca = cos(GOLDEN_ANGLE); + float sa = sin(GOLDEN_ANGLE); + vec2 dir = vec2(cos(angleOffset), sin(angleOffset)); + + vec3 glow = vec3(0.0); + float totalWeight = 0.0; + + // Center tap + float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA)); + glow += original.rgb * centerMask * 2.0; + totalWeight += 2.0; + + for (int i = 1; i < MAX_SAMPLES; i++) { + if (i >= samples) break; + + float fi = float(i); + float dist = sqrt(fi / float(samples)) * radius * radiusJitter; + + vec2 offset = dir * dist * texelSize; + vec3 c = texture(u_image0, v_texCoord + offset).rgb; + float mask = smoothstep(t0, t1, dot(c, LUMA)); + + float w = 1.0 - (dist * dist) / (radius2 * 1.5); + w = max(w, 0.0); + w *= w; + + glow += c * mask * w; + totalWeight += w; + + dir = vec2( + dir.x * ca - dir.y * sa, + dir.x * sa + dir.y * ca + ); + } + + glow *= intensity / max(totalWeight, 0.001); + + if (u_int1 > 0) { + glow *= hexToRgb(u_int1); + } + + vec3 result = blend(original.rgb, glow, u_int0); + result += (noise - 0.5) * (1.0 / 255.0); + + fragColor = vec4(clamp(result, 0.0, 1.0), original.a); +} \ No newline at end of file diff --git a/blueprints/.glsl/Hue_and_Saturation_1.frag b/blueprints/.glsl/Hue_and_Saturation_1.frag new file mode 100644 index 000000000..0fa6810af --- /dev/null +++ b/blueprints/.glsl/Hue_and_Saturation_1.frag @@ -0,0 +1,222 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize +uniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV +uniform float u_float0; // Hue (-180 to 180) +uniform float 
u_float1; // Saturation (-100 to 100) +uniform float u_float2; // Lightness/Brightness (-100 to 100) +uniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges + +in vec2 v_texCoord; +out vec4 fragColor; + +// Color range modes +const int MODE_MASTER = 0; +const int MODE_RED = 1; +const int MODE_YELLOW = 2; +const int MODE_GREEN = 3; +const int MODE_CYAN = 4; +const int MODE_BLUE = 5; +const int MODE_MAGENTA = 6; +const int MODE_COLORIZE = 7; + +// Color space modes +const int COLORSPACE_HSL = 0; +const int COLORSPACE_HSB = 1; + +const float EPSILON = 0.0001; + +//============================================================================= +// RGB <-> HSL Conversions +//============================================================================= + +vec3 rgb2hsl(vec3 c) { + float maxC = max(max(c.r, c.g), c.b); + float minC = min(min(c.r, c.g), c.b); + float delta = maxC - minC; + + float h = 0.0; + float s = 0.0; + float l = (maxC + minC) * 0.5; + + if (delta > EPSILON) { + s = l < 0.5 + ? delta / (maxC + minC) + : delta / (2.0 - maxC - minC); + + if (maxC == c.r) { + h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0); + } else if (maxC == c.g) { + h = (c.b - c.r) / delta + 2.0; + } else { + h = (c.r - c.g) / delta + 4.0; + } + h /= 6.0; + } + + return vec3(h, s, l); +} + +float hue2rgb(float p, float q, float t) { + t = fract(t); + if (t < 1.0/6.0) return p + (q - p) * 6.0 * t; + if (t < 0.5) return q; + if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0; + return p; +} + +vec3 hsl2rgb(vec3 hsl) { + if (hsl.y < EPSILON) return vec3(hsl.z); + + float q = hsl.z < 0.5 + ? hsl.z * (1.0 + hsl.y) + : hsl.z + hsl.y - hsl.z * hsl.y; + float p = 2.0 * hsl.z - q; + + return vec3( + hue2rgb(p, q, hsl.x + 1.0/3.0), + hue2rgb(p, q, hsl.x), + hue2rgb(p, q, hsl.x - 1.0/3.0) + ); +} + +vec3 rgb2hsb(vec3 c) { + float maxC = max(max(c.r, c.g), c.b); + float minC = min(min(c.r, c.g), c.b); + float delta = maxC - minC; + + float h = 0.0; + float s = (maxC > EPSILON) ? delta / maxC : 0.0; + float b = maxC; + + if (delta > EPSILON) { + if (maxC == c.r) { + h = (c.g - c.b) / delta + (c.g < c.b ? 
6.0 : 0.0); + } else if (maxC == c.g) { + h = (c.b - c.r) / delta + 2.0; + } else { + h = (c.r - c.g) / delta + 4.0; + } + h /= 6.0; + } + + return vec3(h, s, b); +} + +vec3 hsb2rgb(vec3 hsb) { + vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0); + return hsb.z * mix(vec3(1.0), rgb, hsb.y); +} + +//============================================================================= +// Color Range Weight Calculation +//============================================================================= + +float hueDistance(float a, float b) { + float d = abs(a - b); + return min(d, 1.0 - d); +} + +float getHueWeight(float hue, float center, float overlap) { + float baseWidth = 1.0 / 6.0; + float feather = baseWidth * overlap; + + float d = hueDistance(hue, center); + + float inner = baseWidth * 0.5; + float outer = inner + feather; + + return 1.0 - smoothstep(inner, outer, d); +} + +float getModeWeight(float hue, int mode, float overlap) { + if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0; + + if (mode == MODE_RED) { + return max( + getHueWeight(hue, 0.0, overlap), + getHueWeight(hue, 1.0, overlap) + ); + } + + float center = float(mode - 1) / 6.0; + return getHueWeight(hue, center, overlap); +} + +//============================================================================= +// Adjustment Functions +//============================================================================= + +float adjustLightness(float l, float amount) { + return amount > 0.0 + ? l + (1.0 - l) * amount + : l + l * amount; +} + +float adjustBrightness(float b, float amount) { + return clamp(b + amount, 0.0, 1.0); +} + +float adjustSaturation(float s, float amount) { + return amount > 0.0 + ? s + (1.0 - s) * amount + : s + s * amount; +} + +vec3 colorize(vec3 rgb, float hue, float sat, float light) { + float lum = dot(rgb, vec3(0.299, 0.587, 0.114)); + float l = adjustLightness(lum, light); + + vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0)); + return hsl2rgb(hsl); +} + +//============================================================================= +// Main +//============================================================================= + +void main() { + vec4 original = texture(u_image0, v_texCoord); + + float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5 + float satAmount = u_float1 / 100.0; // -100..100 -> -1..1 + float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1 + float overlap = u_float3 / 100.0; // 0..100 -> 0..1 + + vec3 result; + + if (u_int0 == MODE_COLORIZE) { + result = colorize(original.rgb, hueShift, satAmount, lightAmount); + fragColor = vec4(result, original.a); + return; + } + + vec3 hsx = (u_int1 == COLORSPACE_HSL) + ? rgb2hsl(original.rgb) + : rgb2hsb(original.rgb); + + float weight = getModeWeight(hsx.x, u_int0, overlap); + + if (u_int0 != MODE_MASTER && hsx.y < EPSILON) { + weight = 0.0; + } + + if (weight > EPSILON) { + float h = fract(hsx.x + hueShift * weight); + float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0); + float v = (u_int1 == COLORSPACE_HSL) + ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0) + : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0); + + vec3 adjusted = vec3(h, s, v); + result = (u_int1 == COLORSPACE_HSL) + ? 
hsl2rgb(adjusted) + : hsb2rgb(adjusted); + } else { + result = original.rgb; + } + + fragColor = vec4(result, original.a); +} diff --git a/blueprints/.glsl/Image_Blur_1.frag b/blueprints/.glsl/Image_Blur_1.frag new file mode 100644 index 000000000..83238111d --- /dev/null +++ b/blueprints/.glsl/Image_Blur_1.frag @@ -0,0 +1,111 @@ +#version 300 es +#pragma passes 2 +precision highp float; + +// Blur type constants +const int BLUR_GAUSSIAN = 0; +const int BLUR_BOX = 1; +const int BLUR_RADIAL = 2; + +// Radial blur config +const int RADIAL_SAMPLES = 12; +const float RADIAL_STRENGTH = 0.0003; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL) +uniform float u_float0; // Blur radius/amount +uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical) + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2.0 * sigma * sigma)); +} + +void main() { + vec2 texelSize = 1.0 / u_resolution; + float radius = max(u_float0, 0.0); + + // Radial (angular) blur - single pass, doesn't use separable + if (u_int0 == BLUR_RADIAL) { + // Only execute on first pass + if (u_pass > 0) { + fragColor0 = texture(u_image0, v_texCoord); + return; + } + + vec2 center = vec2(0.5); + vec2 dir = v_texCoord - center; + float dist = length(dir); + + if (dist < 1e-4) { + fragColor0 = texture(u_image0, v_texCoord); + return; + } + + vec4 sum = vec4(0.0); + float totalWeight = 0.0; + float angleStep = radius * RADIAL_STRENGTH; + + dir /= dist; + + float cosStep = cos(angleStep); + float sinStep = sin(angleStep); + + float negAngle = -float(RADIAL_SAMPLES) * angleStep; + vec2 rotDir = vec2( + dir.x * cos(negAngle) - dir.y * sin(negAngle), + dir.x * sin(negAngle) + dir.y * cos(negAngle) + ); + + for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) { + vec2 uv = center + rotDir * dist; + float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES); + sum += texture(u_image0, uv) * w; + totalWeight += w; + + rotDir = vec2( + rotDir.x * cosStep - rotDir.y * sinStep, + rotDir.x * sinStep + rotDir.y * cosStep + ); + } + + fragColor0 = sum / max(totalWeight, 0.001); + return; + } + + // Separable Gaussian / Box blur + int samples = int(ceil(radius)); + + if (samples == 0) { + fragColor0 = texture(u_image0, v_texCoord); + return; + } + + // Direction: pass 0 = horizontal, pass 1 = vertical + vec2 dir = (u_pass == 0) ? 
vec2(1.0, 0.0) : vec2(0.0, 1.0); + + vec4 color = vec4(0.0); + float totalWeight = 0.0; + float sigma = radius / 2.0; + + for (int i = -samples; i <= samples; i++) { + vec2 offset = dir * float(i) * texelSize; + vec4 sample_color = texture(u_image0, v_texCoord + offset); + + float weight; + if (u_int0 == BLUR_GAUSSIAN) { + weight = gaussian(float(i), sigma); + } else { + // BLUR_BOX + weight = 1.0; + } + + color += sample_color * weight; + totalWeight += weight; + } + + fragColor0 = color / totalWeight; +} diff --git a/blueprints/.glsl/Image_Channels_23.frag b/blueprints/.glsl/Image_Channels_23.frag new file mode 100644 index 000000000..76d70af13 --- /dev/null +++ b/blueprints/.glsl/Image_Channels_23.frag @@ -0,0 +1,19 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; +layout(location = 1) out vec4 fragColor1; +layout(location = 2) out vec4 fragColor2; +layout(location = 3) out vec4 fragColor3; + +void main() { + vec4 color = texture(u_image0, v_texCoord); + // Output each channel as grayscale to separate render targets + fragColor0 = vec4(vec3(color.r), 1.0); // Red channel + fragColor1 = vec4(vec3(color.g), 1.0); // Green channel + fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel + fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel +} diff --git a/blueprints/.glsl/Image_Levels_1.frag b/blueprints/.glsl/Image_Levels_1.frag new file mode 100644 index 000000000..f34ed1d81 --- /dev/null +++ b/blueprints/.glsl/Image_Levels_1.frag @@ -0,0 +1,71 @@ +#version 300 es +precision highp float; + +// Levels Adjustment +// u_int0: channel (0=RGB, 1=R, 2=G, 3=B) default: 0 +// u_float0: input black (0-255) default: 0 +// u_float1: input white (0-255) default: 255 +// u_float2: gamma (0.01-9.99) default: 1.0 +// u_float3: output black (0-255) default: 0 +// u_float4: output white (0-255) default: 255 + +uniform sampler2D u_image0; +uniform int u_int0; +uniform float u_float0; +uniform float u_float1; +uniform float u_float2; +uniform float u_float3; +uniform float u_float4; + +in vec2 v_texCoord; +out vec4 fragColor; + +vec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) { + float inRange = max(inWhite - inBlack, 0.0001); + vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0); + result = pow(result, vec3(1.0 / gamma)); + result = mix(vec3(outBlack), vec3(outWhite), result); + return result; +} + +float applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) { + float inRange = max(inWhite - inBlack, 0.0001); + float result = clamp((value - inBlack) / inRange, 0.0, 1.0); + result = pow(result, 1.0 / gamma); + result = mix(outBlack, outWhite, result); + return result; +} + +void main() { + vec4 texColor = texture(u_image0, v_texCoord); + vec3 color = texColor.rgb; + + float inBlack = u_float0 / 255.0; + float inWhite = u_float1 / 255.0; + float gamma = u_float2; + float outBlack = u_float3 / 255.0; + float outWhite = u_float4 / 255.0; + + vec3 result; + + if (u_int0 == 0) { + result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite); + } + else if (u_int0 == 1) { + result = color; + result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite); + } + else if (u_int0 == 2) { + result = color; + result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite); + } + else if (u_int0 == 3) { + result = color; + result.b = 
applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite); + } + else { + result = color; + } + + fragColor = vec4(result, texColor.a); +} \ No newline at end of file diff --git a/blueprints/.glsl/README.md b/blueprints/.glsl/README.md new file mode 100644 index 000000000..d4084284b --- /dev/null +++ b/blueprints/.glsl/README.md @@ -0,0 +1,28 @@ +# GLSL Shader Sources + +This folder contains the GLSL fragment shaders extracted from blueprint JSON files for easier editing and version control. + +## File Naming Convention + +`{Blueprint_Name}_{node_id}.frag` + +- **Blueprint_Name**: The JSON filename with spaces/special chars replaced by underscores +- **node_id**: The GLSLShader node ID within the subgraph + +## Usage + +```bash +# Extract shaders from blueprint JSONs to this folder +python update_blueprints.py extract + +# Patch edited shaders back into blueprint JSONs +python update_blueprints.py patch +``` + +## Workflow + +1. Run `extract` to pull current shaders from JSONs +2. Edit `.frag` files +3. Run `patch` to update the blueprint JSONs +4. Test +5. Commit both `.frag` files and updated JSONs diff --git a/blueprints/.glsl/Sharpen_23.frag b/blueprints/.glsl/Sharpen_23.frag new file mode 100644 index 000000000..c03f94b66 --- /dev/null +++ b/blueprints/.glsl/Sharpen_23.frag @@ -0,0 +1,28 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0 + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +void main() { + vec2 texel = 1.0 / u_resolution; + + // Sample center and neighbors + vec4 center = texture(u_image0, v_texCoord); + vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y)); + vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y)); + vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0)); + vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0)); + + // Edge enhancement (Laplacian) + vec4 edges = center * 4.0 - top - bottom - left - right; + + // Add edges back scaled by strength + vec4 sharpened = center + edges * u_float0; + + fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a); +} \ No newline at end of file diff --git a/blueprints/.glsl/Unsharp_Mask_26.frag b/blueprints/.glsl/Unsharp_Mask_26.frag new file mode 100644 index 000000000..f5990cb4a --- /dev/null +++ b/blueprints/.glsl/Unsharp_Mask_26.frag @@ -0,0 +1,61 @@ +#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; +uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5 +uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels +uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +float gaussian(float x, float sigma) { + return exp(-(x * x) / (2.0 * sigma * sigma)); +} + +float getLuminance(vec3 color) { + return dot(color, vec3(0.2126, 0.7152, 0.0722)); +} + +void main() { + vec2 texel = 1.0 / u_resolution; + float radius = max(u_float1, 0.5); + float amount = u_float0; + float threshold = u_float2; + + vec4 original = texture(u_image0, v_texCoord); + + // Gaussian blur for the "unsharp" mask + int samples = int(ceil(radius)); + float sigma = radius / 2.0; + + vec4 blurred = vec4(0.0); + float totalWeight = 0.0; + + for (int x = -samples; x <= samples; x++) { + for (int y = -samples; y <= samples; y++) { + vec2 offset = vec2(float(x), float(y)) * texel; + vec4 sample_color = 
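/* likely why the variable carries a suffix: 'sample' is a reserved word in GLSL ES 3.00 */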
texture(u_image0, v_texCoord + offset);
+
+            float dist = length(vec2(float(x), float(y)));
+            float weight = gaussian(dist, sigma);
+            blurred += sample_color * weight;
+            totalWeight += weight;
+        }
+    }
+    blurred /= totalWeight;
+
+    // Unsharp mask = original - blurred
+    vec3 mask = original.rgb - blurred.rgb;
+
+    // Luminance-based threshold with smooth falloff
+    float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));
+    float thresholdScale = smoothstep(0.0, max(threshold, 1e-5), lumaDelta);  // guard: smoothstep is undefined when threshold == 0
+    mask *= thresholdScale;
+
+    // Sharpen: original + mask * amount
+    vec3 sharpened = original.rgb + mask * amount;
+
+    fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);
+}
diff --git a/blueprints/.glsl/update_blueprints.py b/blueprints/.glsl/update_blueprints.py
new file mode 100644
index 000000000..c5bd0ed54
--- /dev/null
+++ b/blueprints/.glsl/update_blueprints.py
@@ -0,0 +1,159 @@
+#!/usr/bin/env python3
+"""
+Shader Blueprint Updater
+
+Syncs GLSL shader files between this folder and blueprint JSON files.
+
+File naming convention:
+    {Blueprint Name}_{node_id}.frag
+
+Usage:
+    python update_blueprints.py extract   # Extract shaders from JSONs to here
+    python update_blueprints.py patch     # Patch shaders back into JSONs
+    python update_blueprints.py           # Same as patch (default)
+"""
+
+import json
+import logging
+import sys
+import re
+from pathlib import Path
+
+logging.basicConfig(level=logging.INFO, format='%(message)s')
+logger = logging.getLogger(__name__)
+
+GLSL_DIR = Path(__file__).parent
+BLUEPRINTS_DIR = GLSL_DIR.parent
+
+
+def get_blueprint_files():
+    """Get all blueprint JSON files."""
+    return sorted(BLUEPRINTS_DIR.glob("*.json"))
+
+
+def sanitize_filename(name):
+    """Convert blueprint name to safe filename."""
+    return re.sub(r'[^\w\-]', '_', name)
+
+
+def extract_shaders():
+    """Extract all shaders from blueprint JSONs to this folder."""
+    extracted = 0
+    for json_path in get_blueprint_files():
+        blueprint_name = json_path.stem
+
+        try:
+            with open(json_path, 'r') as f:
+                data = json.load(f)
+        except (json.JSONDecodeError, IOError) as e:
+            logger.warning("Skipping %s: %s", json_path.name, e)
+            continue
+
+        # Find GLSLShader nodes in subgraphs
+        for subgraph in data.get('definitions', {}).get('subgraphs', []):
+            for node in subgraph.get('nodes', []):
+                if node.get('type') == 'GLSLShader':
+                    node_id = node.get('id')
+                    widgets = node.get('widgets_values', [])
+
+                    # Find shader code (first string that looks like GLSL)
+                    for widget in widgets:
+                        if isinstance(widget, str) and widget.startswith('#version'):
+                            safe_name = sanitize_filename(blueprint_name)
+                            frag_name = f"{safe_name}_{node_id}.frag"
+                            frag_path = GLSL_DIR / frag_name
+
+                            with open(frag_path, 'w') as f:
+                                f.write(widget)
+
+                            logger.info("  Extracted: %s", frag_name)
+                            extracted += 1
+                            break
+
+    logger.info("\nExtracted %d shader(s)", extracted)
+
+
+def patch_shaders():
+    """Patch shaders from this folder back into blueprint JSONs."""
+    # Build lookup: blueprint_name -> [(node_id, shader_code), ...]
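+    # e.g. {"Brightness_and_Contrast": [(143, "#version 300 es\n...")]}
+    # (illustrative shape only; names and ids are parsed from the .frag filenames)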
+ shader_updates = {} + + for frag_path in sorted(GLSL_DIR.glob("*.frag")): + # Parse filename: {blueprint_name}_{node_id}.frag + parts = frag_path.stem.rsplit('_', 1) + if len(parts) != 2: + logger.warning("Skipping %s: invalid filename format", frag_path.name) + continue + + blueprint_name, node_id_str = parts + + try: + node_id = int(node_id_str) + except ValueError: + logger.warning("Skipping %s: invalid node_id", frag_path.name) + continue + + with open(frag_path, 'r') as f: + shader_code = f.read() + + if blueprint_name not in shader_updates: + shader_updates[blueprint_name] = [] + shader_updates[blueprint_name].append((node_id, shader_code)) + + # Apply updates to JSON files + patched = 0 + for json_path in get_blueprint_files(): + blueprint_name = sanitize_filename(json_path.stem) + + if blueprint_name not in shader_updates: + continue + + try: + with open(json_path, 'r') as f: + data = json.load(f) + except (json.JSONDecodeError, IOError) as e: + logger.error("Error reading %s: %s", json_path.name, e) + continue + + modified = False + for node_id, shader_code in shader_updates[blueprint_name]: + # Find the node and update + for subgraph in data.get('definitions', {}).get('subgraphs', []): + for node in subgraph.get('nodes', []): + if node.get('id') == node_id and node.get('type') == 'GLSLShader': + widgets = node.get('widgets_values', []) + if len(widgets) > 0 and widgets[0] != shader_code: + widgets[0] = shader_code + modified = True + logger.info(" Patched: %s (node %d)", json_path.name, node_id) + patched += 1 + + if modified: + with open(json_path, 'w') as f: + json.dump(data, f) + + if patched == 0: + logger.info("No changes to apply.") + else: + logger.info("\nPatched %d shader(s)", patched) + + +def main(): + if len(sys.argv) < 2: + command = "patch" + else: + command = sys.argv[1].lower() + + if command == "extract": + logger.info("Extracting shaders from blueprints...") + extract_shaders() + elif command in ("patch", "update", "apply"): + logger.info("Patching shaders into blueprints...") + patch_shaders() + else: + logger.info(__doc__) + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/blueprints/Brightness and Contrast.json b/blueprints/Brightness and Contrast.json new file mode 100644 index 000000000..2c7e60eb1 --- /dev/null +++ b/blueprints/Brightness and Contrast.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 140, "last_link_id": 0, "nodes": [{"id": 140, "type": "916dff42-6166-4d45-b028-04eaf69fbb35", "pos": [500, 1440], "size": [250, 178], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"]]}, "widgets_values": [], "title": "Brightness and Contrast"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "916dff42-6166-4d45-b028-04eaf69fbb35", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 143, "lastLinkId": 118, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Brightness and Contrast", "inputNode": {"id": -10, "bounding": [360, -176, 120, 60]}, "outputNode": {"id": -20, "bounding": [1410, -176, 120, 60]}, "inputs": [{"id": "a5aae7ea-b511-4045-b5da-94101e269cd7", "name": "images.image0", "type": "IMAGE", "linkIds": [117], "localized_name": "images.image0", "label": "image", "pos": [460, -156]}], "outputs": [{"id": 
"30b72604-69b3-4944-b253-a9099bbd73a9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [118], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [1430, -156]}], "widgets": [], "nodes": [{"id": 4, "type": "PrimitiveFloat", "pos": [540, -280], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "brightness", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [115]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "precision": 1, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [50]}, {"id": 5, "type": "PrimitiveFloat", "pos": [540, -170], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "contrast", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [116]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "precision": 1, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [136, 136, 136]}, {"offset": 0.4, "color": [68, 68, 68]}, {"offset": 0.6, "color": [187, 187, 187]}, {"offset": 0.8, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [0]}, {"id": 143, "type": "GLSLShader", "pos": [840, -280], "size": [400, 212], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 117}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 115}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 116}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [118]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // Brightness slider -100..100\nuniform float u_float1; // Contrast slider -100..100\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float MID_GRAY = 0.18; // 18% reflectance\n\n// sRGB gamma 2.2 approximation\nvec3 srgbToLinear(vec3 c) {\n return pow(max(c, 0.0), vec3(2.2));\n}\n\nvec3 linearToSrgb(vec3 c) {\n return pow(max(c, 0.0), vec3(1.0/2.2));\n}\n\nfloat mapBrightness(float b) {\n return clamp(b / 
100.0, -1.0, 1.0);\n}\n\nfloat mapContrast(float c) {\n return clamp(c / 100.0 + 1.0, 0.0, 2.0);\n}\n\nvoid main() {\n vec4 orig = texture(u_image0, v_texCoord);\n\n float brightness = mapBrightness(u_float0);\n float contrast = mapContrast(u_float1);\n\n vec3 lin = srgbToLinear(orig.rgb);\n\n lin = (lin - MID_GRAY) * contrast + brightness + MID_GRAY;\n\n // Convert back to sRGB\n vec3 result = linearToSrgb(clamp(lin, 0.0, 1.0));\n\n fragColor = vec4(result, orig.a);\n}\n", "from_input"]}], "groups": [], "links": [{"id": 115, "origin_id": 4, "origin_slot": 0, "target_id": 143, "target_slot": 2, "type": "FLOAT"}, {"id": 116, "origin_id": 5, "origin_slot": 0, "target_id": 143, "target_slot": 3, "type": "FLOAT"}, {"id": 117, "origin_id": -10, "origin_slot": 0, "target_id": 143, "target_slot": 0, "type": "IMAGE"}, {"id": 118, "origin_id": 143, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}, "extra": {}} diff --git a/blueprints/Canny to Image (Z-Image-Turbo).json b/blueprints/Canny to Image (Z-Image-Turbo).json new file mode 100644 index 000000000..8b78a834a --- /dev/null +++ b/blueprints/Canny to Image (Z-Image-Turbo).json @@ -0,0 +1 @@ +{"id": "e046dd74-e2a7-4f31-a75b-5e11a8c72d4e", "revision": 0, "last_node_id": 18, "last_link_id": 32, "nodes": [{"id": 18, "type": "c84f7959-3738-422b-ba6e-5808b5e90101", "pos": [300, 3830], "size": [400, 460], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "control image", "name": "image", "type": "IMAGE", "link": null}, {"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"label": "canny low threshold", "name": "low_threshold", "type": "FLOAT", "widget": {"name": "low_threshold"}, "link": null}, {"label": "canny high threshold", "name": "high_threshold", "type": "FLOAT", "widget": {"name": "high_threshold"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": null}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "low_threshold"], ["-1", "high_threshold"], ["7", "seed"], ["7", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "name"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", 0.3, 0.4, null, null, "z_image_turbo_bf16.safetensors", "qwen_3_4b.safetensors", "ae.safetensors", "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "c84f7959-3738-422b-ba6e-5808b5e90101", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 18, "lastLinkId": 32, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Canny to Image (Z-Image-Turbo)", "inputNode": {"id": -10, "bounding": [-280, 4960, 158.880859375, 200]}, "outputNode": {"id": -20, "bounding": [1598.6038576146689, 4936.043696127976, 120, 60]}, "inputs": [{"id": "29ca271b-8f63-4e7b-a4b8-c9b4192ada0b", "name": "image", "type": "IMAGE", "linkIds": [26], "label": "control image", "pos": [-141.119140625, 4980]}, {"id": "b6549f90-39ee-4b79-9e00-af4d9df969fe", "name": "text", "type": "STRING", "linkIds": [16], "label": "prompt", "pos": [-141.119140625, 5000]}, 
{"id": "6bd34d18-79f6-470f-94df-ca14c84ef3d8", "name": "low_threshold", "type": "FLOAT", "linkIds": [24], "label": "canny low threshold", "pos": [-141.119140625, 5020]}, {"id": "bbced993-057f-4d2d-909c-d791be73d1d2", "name": "high_threshold", "type": "FLOAT", "linkIds": [25], "label": "canny high threshold", "pos": [-141.119140625, 5040]}, {"id": "db7969bf-4b05-48a0-9598-87d3ac85b505", "name": "unet_name", "type": "COMBO", "linkIds": [29], "pos": [-141.119140625, 5060]}, {"id": "925b611c-5edf-406f-8dc5-7fec07d049a7", "name": "clip_name", "type": "COMBO", "linkIds": [30], "pos": [-141.119140625, 5080]}, {"id": "b4cf508b-4753-40d2-8c83-5a424237ee07", "name": "vae_name", "type": "COMBO", "linkIds": [31], "pos": [-141.119140625, 5100]}, {"id": "bd948f38-3a11-4091-99fc-bb2b3511bcd2", "name": "name", "type": "COMBO", "linkIds": [32], "pos": [-141.119140625, 5120]}], "outputs": [{"id": "47f9a22d-6619-4917-9447-a7d5d08dceb5", "name": "IMAGE", "type": "IMAGE", "linkIds": [18], "pos": [1618.6038576146689, 4956.043696127976]}], "widgets": [], "nodes": [{"id": 1, "type": "CLIPLoader", "pos": [228.60376290329597, 4700.188357350136], "size": [270, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 30}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [14]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_3_4b.safetensors", "lumina2", "default"]}, {"id": 2, "type": "UNETLoader", "pos": [228.60376290329597, 4550.1883046176445], "size": [270, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 29}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [9]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "UNETLoader", "models": [{"name": "z_image_turbo_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["z_image_turbo_bf16.safetensors", "default"]}, {"id": 3, "type": "VAELoader", "pos": [228.60376290329597, 4880.18831633181], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 31}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [2, 11]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", 
"Node name for S&R": "VAELoader", "models": [{"name": "ae.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ae.safetensors"]}, {"id": 4, "type": "ModelPatchLoader", "pos": [228.60376290329597, 5010.1884895078], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "name", "name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": 32}], "outputs": [{"localized_name": "MODEL_PATCH", "name": "MODEL_PATCH", "type": "MODEL_PATCH", "links": [10]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ModelPatchLoader", "models": [{"name": "Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "url": "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "directory": "model_patches"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["Z-Image-Turbo-Fun-Controlnet-Union.safetensors"]}, {"id": 6, "type": "ModelSamplingAuraFlow", "pos": [998.6039930366841, 4490.18831829042], "size": [290, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 3}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [4]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 7, "type": "KSampler", "pos": [998.6039930366841, 4600.188351166619], "size": [300, 460], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 4}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 5}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 6}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 7}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, 
"randomize", 9, 1, "res_multistep", "simple", 1]}, {"id": 8, "type": "ConditioningZeroOut", "pos": [748.2704434516113, 5044.855005348689], "size": [204.134765625, 26.000000000000004], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 8}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [6]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "ConditioningZeroOut", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 9, "type": "QwenImageDiffsynthControlnet", "pos": [608.2704174118008, 5204.85499785943], "size": [290, 138], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 9}, {"localized_name": "model_patch", "name": "model_patch", "type": "MODEL_PATCH", "link": 10}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 11}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 22}, {"localized_name": "mask", "name": "mask", "shape": 7, "type": "MASK", "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [3]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "QwenImageDiffsynthControlnet", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 12, "type": "CLIPTextEncode", "pos": [548.2704310845766, 4544.854974431101], "size": [400, 330], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 14}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 16}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [5, 8]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 5, "type": "VAEDecode", "pos": [1338.6038576146689, 4500.188344983101], "size": [200, 46], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 1}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 2}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [18]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 15, "type": "ImageScaleToTotalPixels", "pos": [220, 5220], "size": [270, 106], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 26}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": 
"megapixels", "name": "megapixels", "type": "FLOAT", "widget": {"name": "megapixels"}, "link": null}, {"localized_name": "resolution_steps", "name": "resolution_steps", "type": "INT", "widget": {"name": "resolution_steps"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "ImageScaleToTotalPixels"}, "widgets_values": ["nearest-exact", 1, 1]}, {"id": 11, "type": "GetImageSize", "pos": [540, 5450], "size": [140, 66], "flags": {"collapsed": false}, "order": 10, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 23}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [12]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [13]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": null}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 10, "type": "EmptySD3LatentImage", "pos": [760, 5430], "size": [260, 106], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 12}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 13}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [7]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "EmptySD3LatentImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1024, 1024, 1]}, {"id": 14, "type": "Canny", "pos": [220, 5380], "size": [270, 82], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 27}, {"localized_name": "low_threshold", "name": "low_threshold", "type": "FLOAT", "widget": {"name": "low_threshold"}, "link": 24}, {"localized_name": "high_threshold", "name": "high_threshold", "type": "FLOAT", "widget": {"name": "high_threshold"}, "link": 25}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [22, 23, 28]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "Canny"}, "widgets_values": [0.3, 0.4]}, {"id": 16, "type": "PreviewImage", "pos": [220, 5520], "size": [260, 270], "flags": {}, "order": 14, "mode": 4, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 28}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "PreviewImage"}, "widgets_values": []}], "groups": [{"id": 1, "title": "Prompt", "bounding": [530, 4460, 440, 630], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Models", "bounding": [210, 4460, 300, 640], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Apple ControlNet", "bounding": [530, 5120, 440, 260], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 1, "origin_id": 7, "origin_slot": 0, "target_id": 5, "target_slot": 0, "type": "LATENT"}, 
{"id": 2, "origin_id": 3, "origin_slot": 0, "target_id": 5, "target_slot": 1, "type": "VAE"}, {"id": 3, "origin_id": 9, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "MODEL"}, {"id": 4, "origin_id": 6, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "MODEL"}, {"id": 5, "origin_id": 12, "origin_slot": 0, "target_id": 7, "target_slot": 1, "type": "CONDITIONING"}, {"id": 6, "origin_id": 8, "origin_slot": 0, "target_id": 7, "target_slot": 2, "type": "CONDITIONING"}, {"id": 7, "origin_id": 10, "origin_slot": 0, "target_id": 7, "target_slot": 3, "type": "LATENT"}, {"id": 8, "origin_id": 12, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "CONDITIONING"}, {"id": 9, "origin_id": 2, "origin_slot": 0, "target_id": 9, "target_slot": 0, "type": "MODEL"}, {"id": 10, "origin_id": 4, "origin_slot": 0, "target_id": 9, "target_slot": 1, "type": "MODEL_PATCH"}, {"id": 11, "origin_id": 3, "origin_slot": 0, "target_id": 9, "target_slot": 2, "type": "VAE"}, {"id": 12, "origin_id": 11, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "INT"}, {"id": 13, "origin_id": 11, "origin_slot": 1, "target_id": 10, "target_slot": 1, "type": "INT"}, {"id": 14, "origin_id": 1, "origin_slot": 0, "target_id": 12, "target_slot": 0, "type": "CLIP"}, {"id": 16, "origin_id": -10, "origin_slot": 1, "target_id": 12, "target_slot": 1, "type": "STRING"}, {"id": 18, "origin_id": 5, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 22, "origin_id": 14, "origin_slot": 0, "target_id": 9, "target_slot": 3, "type": "IMAGE"}, {"id": 23, "origin_id": 14, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 24, "origin_id": -10, "origin_slot": 2, "target_id": 14, "target_slot": 1, "type": "FLOAT"}, {"id": 25, "origin_id": -10, "origin_slot": 3, "target_id": 14, "target_slot": 2, "type": "FLOAT"}, {"id": 26, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 15, "origin_slot": 0, "target_id": 14, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 14, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "IMAGE"}, {"id": 29, "origin_id": -10, "origin_slot": 4, "target_id": 2, "target_slot": 0, "type": "COMBO"}, {"id": 30, "origin_id": -10, "origin_slot": 5, "target_id": 1, "target_slot": 0, "type": "COMBO"}, {"id": 31, "origin_id": -10, "origin_slot": 6, "target_id": 3, "target_slot": 0, "type": "COMBO"}, {"id": 32, "origin_id": -10, "origin_slot": 7, "target_id": 4, "target_slot": 0, "type": "COMBO"}], "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "category": "Image generation and editing/Canny to image"}]}, "config": {}, "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true, "ds": {"scale": 0.967267584583181, "offset": [444.759060017523, -3564.372163194443]}}, "version": 0.4} diff --git a/blueprints/Canny to Video (LTX 2.0).json b/blueprints/Canny to Video (LTX 2.0).json new file mode 100644 index 000000000..cd2c4e594 --- /dev/null +++ b/blueprints/Canny to Video (LTX 2.0).json @@ -0,0 +1 @@ +{"id": "02f6166f-32f8-4673-b861-76be1464cba5", "revision": 0, "last_node_id": 155, "last_link_id": 391, "nodes": [{"id": 1, "type": "884e1862-7567-4e72-bd2a-fd4fdfd06320", "pos": [1519.643633934233, 
3717.5350173634242], "size": [400, 500], "flags": {"collapsed": false}, "order": 0, "mode": 0, "inputs": [{"name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"label": "canny_images", "name": "image", "type": "IMAGE", "link": null}, {"label": "image_strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}, {"label": "disable_first_frame", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": null}, {"label": "first_frame", "name": "image_1", "type": "IMAGE", "link": null}, {"name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": null}, {"name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"name": "text_encoder", "type": "COMBO", "widget": {"name": "text_encoder"}, "link": null}, {"label": "distlled_lora", "name": "lora_name_1", "type": "COMBO", "widget": {"name": "lora_name_1"}, "link": null}, {"label": "upscale_model", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "strength"], ["143", "noise_seed"], ["126", "control_after_generate"], ["-1", "bypass"], ["-1", "ckpt_name"], ["-1", "lora_name"], ["-1", "text_encoder"], ["-1", "lora_name_1"], ["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.7.0", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", 1, null, null, false, "ltx-2-19b-dev-fp8.safetensors", "ltx-2-19b-ic-lora-canny-control.safetensors", "gemma_3_12B_it_fp4_mixed.safetensors", "ltx-2-19b-distilled-lora-384.safetensors", "ltx-2-spatial-upscaler-x2-1.0.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "884e1862-7567-4e72-bd2a-fd4fdfd06320", "version": 1, "state": {"lastGroupId": 11, "lastNodeId": 155, "lastLinkId": 391, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Canny to Video (LTX 2.0)", "inputNode": {"id": -10, "bounding": [-2180, 4070, 146.8515625, 240]}, "outputNode": {"id": -20, "bounding": [1750, 4090, 120, 60]}, "inputs": [{"id": "0f1d2f96-933a-4a7b-8f1a-7b49fc4ade09", "name": "text", "type": "STRING", "linkIds": [345], "pos": [-2053.1484375, 4090]}, {"id": "35a07084-3ecf-482a-a330-b40278770ca3", "name": "image", "type": "IMAGE", "linkIds": [348, 349], "label": "canny_images", "pos": [-2053.1484375, 4110]}, {"id": "59430efe-1090-4e36-8afe-b21ce7f4268b", "name": "strength", "type": "FLOAT", "linkIds": [370, 371], "label": "image_strength", "pos": [-2053.1484375, 4130]}, {"id": "6145a9b9-68ed-4956-89f7-7a5ebdd5c99e", "name": "bypass", "type": "BOOLEAN", "linkIds": [363, 368], "label": "disable_first_frame", "pos": [-2053.1484375, 4150]}, {"id": "bea20802-d654-4287-a8ef-0f834314bcf9", "name": "image_1", "type": "IMAGE", "linkIds": [364, 379], "label": "first_frame", "pos": [-2053.1484375, 4170]}, {"id": "4e2f26b5-9ad6-49a6-8e90-0ed24fc6a423", "name": "ckpt_name", "type": "COMBO", "linkIds": [385, 386, 387], "pos": [-2053.1484375, 4190]}, {"id": "81fdfcf3-92ca-4f8d-b13d-d22758231530", "name": "lora_name", "type": "COMBO", "linkIds": [388], "pos": [-2053.1484375, 4210]}, {"id": "3fa7991e-4419-44a7-9377-1b6125fef355", "name": "text_encoder", "type": "COMBO", "linkIds": [389], "pos": [-2053.1484375, 4230]}, {"id": "b9277d33-2f18-47bb-95ab-666799e8730f", "name": "lora_name_1", 
"type": "COMBO", "linkIds": [390], "label": "distlled_lora", "pos": [-2053.1484375, 4250]}, {"id": "80b2e9cf-e1a7-462f-ae0d-ffb4ba668a65", "name": "model_name", "type": "COMBO", "linkIds": [391], "label": "upscale_model", "pos": [-2053.1484375, 4270]}], "outputs": [{"id": "4e837941-de2d-4df8-8f94-686e24036897", "name": "VIDEO", "type": "VIDEO", "linkIds": [304], "localized_name": "VIDEO", "pos": [1770, 4110]}], "widgets": [], "nodes": [{"id": 93, "type": "CFGGuider", "pos": [-698, 3670], "size": [270, 106.66666666666667], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 326}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 309}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 311}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [261]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 94, "type": "KSamplerSelect", "pos": [-698, 3840], "size": [270, 68.88020833333334], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [262]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["euler"]}, {"id": 99, "type": "ManualSigmas", "pos": [410, 3850], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "STRING", "widget": {"name": "sigmas"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [278]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "ManualSigmas", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["0.909375, 0.725, 0.421875, 0.0"]}, {"id": 101, "type": "LTXVConcatAVLatent", "pos": [410, 4100], "size": [270, 110], "flags": {}, "order": 18, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 365}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 266}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [279]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 108, "type": "CFGGuider", "pos": [410, 3700], "size": [270, 98], "flags": {}, "order": 22, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 280}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 281}, {"localized_name": 
"negative", "name": "negative", "type": "CONDITIONING", "link": 282}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [276]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.71", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 111, "type": "LTXVEmptyLatentAudio", "pos": [-1100, 4810], "size": [270, 120], "flags": {}, "order": 24, "mode": 0, "inputs": [{"localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 383}, {"localized_name": "frames_number", "name": "frames_number", "type": "INT", "widget": {"name": "frames_number"}, "link": 329}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "INT", "widget": {"name": "frame_rate"}, "link": 354}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "Latent", "name": "Latent", "type": "LATENT", "links": [300]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVEmptyLatentAudio", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [97, 25, 1]}, {"id": 123, "type": "SamplerCustomAdvanced", "pos": [-388, 3520], "size": [213.125, 120], "flags": {}, "order": 31, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 260}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 261}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 262}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 263}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 323}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": [272]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 114, "type": "LTXVConditioning", "pos": [-1134, 4140], "size": [270, 86.66666666666667], "flags": {}, "order": 27, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 292}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 293}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "FLOAT", "widget": {"name": "frame_rate"}, "link": 355}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [313]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [314]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVConditioning", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 119, "type": "CLIPTextEncode", "pos": [-1164, 3880], "size": [400, 200], "flags": {}, "order": 14, "mode": 0, 
"inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 294}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [293]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["blurry, low quality, still frame, frames, watermark, overlay, titles, has blurbox, has subtitles"], "color": "#323", "bgcolor": "#535"}, {"id": 116, "type": "LTXVConcatAVLatent", "pos": [-520, 4700], "size": [187.5, 60], "flags": {}, "order": 29, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 324}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 300}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [322, 323]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 122, "type": "LTXVSeparateAVLatent", "pos": [-394, 3800], "size": [240, 46], "flags": {}, "order": 30, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 272}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "links": [270]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [266]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 124, "type": "CLIPTextEncode", "pos": [-1174.999849798713, 3514.000055195033], "size": [410, 320], "flags": {}, "order": 32, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 295}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 345}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [292]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 98, "type": "KSamplerSelect", "pos": [410, 3980], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [277]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["gradient_estimation"]}, {"id": 95, "type": "LTXVScheduler", "pos": [-700, 3980], "size": 
[270, 170], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 322}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "max_shift", "name": "max_shift", "type": "FLOAT", "widget": {"name": "max_shift"}, "link": null}, {"localized_name": "base_shift", "name": "base_shift", "type": "FLOAT", "widget": {"name": "base_shift"}, "link": null}, {"localized_name": "stretch", "name": "stretch", "type": "BOOLEAN", "widget": {"name": "stretch"}, "link": null}, {"localized_name": "terminal", "name": "terminal", "type": "FLOAT", "widget": {"name": "terminal"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [263]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVScheduler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [20, 2.05, 0.95, true, 0.1]}, {"id": 126, "type": "RandomNoise", "pos": [-698, 3520], "size": [270, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [260]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize"]}, {"id": 107, "type": "SamplerCustomAdvanced", "pos": [710, 3570], "size": [212.38333740234376, 106], "flags": {}, "order": 21, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 347}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 276}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 277}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 278}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 279}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": []}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": [336]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 143, "type": "RandomNoise", "pos": [410, 3570], "size": [270, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [347]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "fixed"]}, {"id": 139, "type": "LTXVAudioVAEDecode", "pos": [1130, 3840], "size": [240, 46], "flags": {}, "order": 35, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", 
"type": "LATENT", "link": 338}, {"label": "Audio VAE", "localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 384}], "outputs": [{"localized_name": "Audio", "name": "Audio", "type": "AUDIO", "links": [339]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVAudioVAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 106, "type": "CreateVideo", "pos": [1420, 3760], "size": [270, 78], "flags": {}, "order": 20, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 352}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 339}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 356}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [304]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "CreateVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 134, "type": "LoraLoaderModelOnly", "pos": [-1650, 3760], "size": [420, 82], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 325}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 388}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [326, 327]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-ic-lora-canny-control.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Canny-Control/resolve/main/ltx-2-19b-ic-lora-canny-control.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-ic-lora-canny-control.safetensors", 1], "color": "#322", "bgcolor": "#533"}, {"id": 138, "type": "LTXVSeparateAVLatent", "pos": [730, 3730], "size": [193.2916015625, 46], "flags": {}, "order": 34, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 336}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "links": [337, 351]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [338]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 144, "type": "VAEDecodeTiled", "pos": [1120, 3640], "size": [270, 150], "flags": {}, "order": 36, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 351}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 353}, {"localized_name": "tile_size", "name": "tile_size", "type": "INT", "widget": {"name": "tile_size"}, "link": null}, {"localized_name": "overlap", "name": "overlap", 
"type": "INT", "widget": {"name": "overlap"}, "link": null}, {"localized_name": "temporal_size", "name": "temporal_size", "type": "INT", "widget": {"name": "temporal_size"}, "link": null}, {"localized_name": "temporal_overlap", "name": "temporal_overlap", "type": "INT", "widget": {"name": "temporal_overlap"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [352]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "VAEDecodeTiled", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [512, 64, 4096, 8]}, {"id": 113, "type": "VAEDecode", "pos": [1130, 3530], "size": [240, 50], "flags": {}, "order": 26, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 337}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 291}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 110, "type": "GetImageSize", "pos": [-1630, 4450], "size": [260, 80], "flags": {}, "order": 23, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 349}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [296]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [297]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": [329, 330]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 145, "type": "PrimitiveInt", "pos": [-1630, 4620], "size": [270, 82], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "INT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "INT", "name": "INT", "type": "INT", "links": [354]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveInt", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24, "fixed"]}, {"id": 148, "type": "PrimitiveFloat", "pos": [-1630, 4750], "size": [270, 58], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [355, 356]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveFloat", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24]}, {"id": 115, "type": "EmptyLTXVLatentVideo", "pos": [-1100, 4610], "size": [270, 146.66666666666669], "flags": {}, "order": 28, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 296}, {"localized_name": "height", "name": 
"height", "type": "INT", "widget": {"name": "height"}, "link": 297}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 330}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [360]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "EmptyLTXVLatentVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [768, 512, 97, 1]}, {"id": 149, "type": "LTXVImgToVideoInplace", "pos": [-1090, 4400], "size": [270, 152], "flags": {}, "order": 37, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 359}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 364}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 360}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 370}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 363}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [357]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 118, "type": "Reroute", "pos": [-230, 4210], "size": [75, 26], "flags": {}, "order": 13, "mode": 0, "inputs": [{"name": "", "type": "*", "link": 303}], "outputs": [{"name": "", "type": "VAE", "links": [289, 291, 367]}], "properties": {"showOutputText": false, "horizontal": false}}, {"id": 151, "type": "LTXVImgToVideoInplace", "pos": [-20, 4070], "size": [270, 182], "flags": {}, "order": 38, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 367}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 379}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 366}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 371}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 368}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [365]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 104, "type": "LTXVCropGuides", "pos": [-10, 3840], "size": [240, 66], "flags": {}, "order": 19, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 310}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 312}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 270}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [281]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [282]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", 
"slot_index": 2, "links": [287]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVCropGuides", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 112, "type": "LTXVLatentUpsampler", "pos": [-10, 3960], "size": [260, 66], "flags": {}, "order": 25, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 287}, {"localized_name": "upscale_model", "name": "upscale_model", "type": "LATENT_UPSCALE_MODEL", "link": 288}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 289}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [366]}], "title": "spatial", "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVLatentUpsampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 132, "type": "LTXVAddGuide", "pos": [-600, 4420], "size": [270, 209.16666666666669], "flags": {}, "order": 33, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 313}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 314}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 328}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 357}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 348}, {"localized_name": "frame_idx", "name": "frame_idx", "type": "INT", "widget": {"name": "frame_idx"}, "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [309, 310]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [311, 312]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [324]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "LTXVAddGuide", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, 1]}, {"id": 103, "type": "CheckpointLoaderSimple", "pos": [-1650, 3590], "size": [420, 98], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 385}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [325]}, {"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": []}, {"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [303, 328, 353, 359]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CheckpointLoaderSimple", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}, {"id": 97, "type": "LTXAVTextEncoderLoader", "pos": [-1650, 4040], "size": [420, 106], "flags": {}, "order": 8, "mode": 0, "inputs": 
[{"localized_name": "text_encoder", "name": "text_encoder", "type": "COMBO", "widget": {"name": "text_encoder"}, "link": 389}, {"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 387}, {"localized_name": "device", "name": "device", "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [294, 295]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXAVTextEncoderLoader", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}, {"name": "gemma_3_12B_it_fp4_mixed.safetensors", "url": "https://huggingface.co/Comfy-Org/ltx-2/resolve/main/split_files/text_encoders/gemma_3_12B_it_fp4_mixed.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["gemma_3_12B_it_fp4_mixed.safetensors", "ltx-2-19b-dev-fp8.safetensors", "default"]}, {"id": 105, "type": "LoraLoaderModelOnly", "pos": [-70, 3570], "size": [390, 82], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 327}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 390}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [280]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-distilled-lora-384.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-distilled-lora-384.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-distilled-lora-384.safetensors", 1]}, {"id": 100, "type": "LatentUpscaleModelLoader", "pos": [-70, 3700], "size": [390, 60], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 391}], "outputs": [{"localized_name": "LATENT_UPSCALE_MODEL", "name": "LATENT_UPSCALE_MODEL", "type": "LATENT_UPSCALE_MODEL", "links": [288]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LatentUpscaleModelLoader", "models": [{"name": "ltx-2-spatial-upscaler-x2-1.0.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-spatial-upscaler-x2-1.0.safetensors", "directory": "latent_upscale_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-spatial-upscaler-x2-1.0.safetensors"]}, {"id": 154, "type": "MarkdownNote", "pos": [-1660, 4870], "size": [350, 170], "flags": {"collapsed": false}, "order": 10, "mode": 0, "inputs": [], "outputs": [], "title": "Frame Rate Note", "properties": {}, "widgets_values": ["Please make sure the frame rate value is the same in both boxes"], "color": "#222", "bgcolor": "#000"}, {"id": 155, "type": "LTXVAudioVAELoader", 
"pos": [-1640, 3910], "size": [400, 58], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 386}], "outputs": [{"localized_name": "Audio VAE", "name": "Audio VAE", "type": "VAE", "links": [383, 384]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "LTXVAudioVAELoader"}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}], "groups": [{"id": 1, "title": "Model", "bounding": [-1660, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Basic Sampling", "bounding": [-700, 3440, 570, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Prompt", "bounding": [-1180, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 5, "title": "Latent", "bounding": [-1180, 4290, 1050, 680], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 9, "title": "Upscale Sampling(2x)", "bounding": [-100, 3440, 1090, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 6, "title": "Sampler", "bounding": [350, 3480, 620, 750], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 7, "title": "Model", "bounding": [-90, 3480, 430, 310], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 11, "title": "Frame rate", "bounding": [-1640, 4550, 290, 271.6], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 326, "origin_id": 134, "origin_slot": 0, "target_id": 93, "target_slot": 0, "type": "MODEL"}, {"id": 309, "origin_id": 132, "origin_slot": 0, "target_id": 93, "target_slot": 1, "type": "CONDITIONING"}, {"id": 311, "origin_id": 132, "origin_slot": 1, "target_id": 93, "target_slot": 2, "type": "CONDITIONING"}, {"id": 266, "origin_id": 122, "origin_slot": 1, "target_id": 101, "target_slot": 1, "type": "LATENT"}, {"id": 280, "origin_id": 105, "origin_slot": 0, "target_id": 108, "target_slot": 0, "type": "MODEL"}, {"id": 281, "origin_id": 104, "origin_slot": 0, "target_id": 108, "target_slot": 1, "type": "CONDITIONING"}, {"id": 282, "origin_id": 104, "origin_slot": 1, "target_id": 108, "target_slot": 2, "type": "CONDITIONING"}, {"id": 329, "origin_id": 110, "origin_slot": 2, "target_id": 111, "target_slot": 1, "type": "INT"}, {"id": 260, "origin_id": 126, "origin_slot": 0, "target_id": 123, "target_slot": 0, "type": "NOISE"}, {"id": 261, "origin_id": 93, "origin_slot": 0, "target_id": 123, "target_slot": 1, "type": "GUIDER"}, {"id": 262, "origin_id": 94, "origin_slot": 0, "target_id": 123, "target_slot": 2, "type": "SAMPLER"}, {"id": 263, "origin_id": 95, "origin_slot": 0, "target_id": 123, "target_slot": 3, "type": "SIGMAS"}, {"id": 323, "origin_id": 116, "origin_slot": 0, "target_id": 123, "target_slot": 4, "type": "LATENT"}, {"id": 296, "origin_id": 110, "origin_slot": 0, "target_id": 115, "target_slot": 0, "type": "INT"}, {"id": 297, "origin_id": 110, "origin_slot": 1, "target_id": 115, "target_slot": 1, "type": "INT"}, {"id": 330, "origin_id": 110, "origin_slot": 2, "target_id": 115, "target_slot": 2, "type": "INT"}, {"id": 325, "origin_id": 103, "origin_slot": 0, "target_id": 134, "target_slot": 0, "type": "MODEL"}, {"id": 292, "origin_id": 124, "origin_slot": 0, "target_id": 114, "target_slot": 0, "type": "CONDITIONING"}, {"id": 293, "origin_id": 119, "origin_slot": 0, "target_id": 114, "target_slot": 1, "type": "CONDITIONING"}, {"id": 294, "origin_id": 97, "origin_slot": 0, "target_id": 119, "target_slot": 0, "type": "CLIP"}, {"id": 324, "origin_id": 132, 
"origin_slot": 2, "target_id": 116, "target_slot": 0, "type": "LATENT"}, {"id": 300, "origin_id": 111, "origin_slot": 0, "target_id": 116, "target_slot": 1, "type": "LATENT"}, {"id": 313, "origin_id": 114, "origin_slot": 0, "target_id": 132, "target_slot": 0, "type": "CONDITIONING"}, {"id": 314, "origin_id": 114, "origin_slot": 1, "target_id": 132, "target_slot": 1, "type": "CONDITIONING"}, {"id": 328, "origin_id": 103, "origin_slot": 2, "target_id": 132, "target_slot": 2, "type": "VAE"}, {"id": 272, "origin_id": 123, "origin_slot": 0, "target_id": 122, "target_slot": 0, "type": "LATENT"}, {"id": 336, "origin_id": 107, "origin_slot": 1, "target_id": 138, "target_slot": 0, "type": "LATENT"}, {"id": 339, "origin_id": 139, "origin_slot": 0, "target_id": 106, "target_slot": 1, "type": "AUDIO"}, {"id": 295, "origin_id": 97, "origin_slot": 0, "target_id": 124, "target_slot": 0, "type": "CLIP"}, {"id": 303, "origin_id": 103, "origin_slot": 2, "target_id": 118, "target_slot": 0, "type": "VAE"}, {"id": 338, "origin_id": 138, "origin_slot": 1, "target_id": 139, "target_slot": 0, "type": "LATENT"}, {"id": 337, "origin_id": 138, "origin_slot": 0, "target_id": 113, "target_slot": 0, "type": "LATENT"}, {"id": 291, "origin_id": 118, "origin_slot": 0, "target_id": 113, "target_slot": 1, "type": "VAE"}, {"id": 276, "origin_id": 108, "origin_slot": 0, "target_id": 107, "target_slot": 1, "type": "GUIDER"}, {"id": 277, "origin_id": 98, "origin_slot": 0, "target_id": 107, "target_slot": 2, "type": "SAMPLER"}, {"id": 278, "origin_id": 99, "origin_slot": 0, "target_id": 107, "target_slot": 3, "type": "SIGMAS"}, {"id": 279, "origin_id": 101, "origin_slot": 0, "target_id": 107, "target_slot": 4, "type": "LATENT"}, {"id": 327, "origin_id": 134, "origin_slot": 0, "target_id": 105, "target_slot": 0, "type": "MODEL"}, {"id": 310, "origin_id": 132, "origin_slot": 0, "target_id": 104, "target_slot": 0, "type": "CONDITIONING"}, {"id": 312, "origin_id": 132, "origin_slot": 1, "target_id": 104, "target_slot": 1, "type": "CONDITIONING"}, {"id": 270, "origin_id": 122, "origin_slot": 0, "target_id": 104, "target_slot": 2, "type": "LATENT"}, {"id": 287, "origin_id": 104, "origin_slot": 2, "target_id": 112, "target_slot": 0, "type": "LATENT"}, {"id": 288, "origin_id": 100, "origin_slot": 0, "target_id": 112, "target_slot": 1, "type": "LATENT_UPSCALE_MODEL"}, {"id": 289, "origin_id": 118, "origin_slot": 0, "target_id": 112, "target_slot": 2, "type": "VAE"}, {"id": 322, "origin_id": 116, "origin_slot": 0, "target_id": 95, "target_slot": 0, "type": "LATENT"}, {"id": 304, "origin_id": 106, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 345, "origin_id": -10, "origin_slot": 0, "target_id": 124, "target_slot": 1, "type": "STRING"}, {"id": 347, "origin_id": 143, "origin_slot": 0, "target_id": 107, "target_slot": 0, "type": "NOISE"}, {"id": 348, "origin_id": -10, "origin_slot": 1, "target_id": 132, "target_slot": 4, "type": "IMAGE"}, {"id": 349, "origin_id": -10, "origin_slot": 1, "target_id": 110, "target_slot": 0, "type": "IMAGE"}, {"id": 351, "origin_id": 138, "origin_slot": 0, "target_id": 144, "target_slot": 0, "type": "LATENT"}, {"id": 352, "origin_id": 144, "origin_slot": 0, "target_id": 106, "target_slot": 0, "type": "IMAGE"}, {"id": 353, "origin_id": 103, "origin_slot": 2, "target_id": 144, "target_slot": 1, "type": "VAE"}, {"id": 354, "origin_id": 145, "origin_slot": 0, "target_id": 111, "target_slot": 2, "type": "INT"}, {"id": 355, "origin_id": 148, "origin_slot": 0, "target_id": 114, 
"target_slot": 2, "type": "FLOAT"}, {"id": 356, "origin_id": 148, "origin_slot": 0, "target_id": 106, "target_slot": 2, "type": "FLOAT"}, {"id": 357, "origin_id": 149, "origin_slot": 0, "target_id": 132, "target_slot": 3, "type": "LATENT"}, {"id": 359, "origin_id": 103, "origin_slot": 2, "target_id": 149, "target_slot": 0, "type": "VAE"}, {"id": 360, "origin_id": 115, "origin_slot": 0, "target_id": 149, "target_slot": 2, "type": "LATENT"}, {"id": 363, "origin_id": -10, "origin_slot": 3, "target_id": 149, "target_slot": 4, "type": "BOOLEAN"}, {"id": 364, "origin_id": -10, "origin_slot": 4, "target_id": 149, "target_slot": 1, "type": "IMAGE"}, {"id": 365, "origin_id": 151, "origin_slot": 0, "target_id": 101, "target_slot": 0, "type": "LATENT"}, {"id": 366, "origin_id": 112, "origin_slot": 0, "target_id": 151, "target_slot": 2, "type": "LATENT"}, {"id": 367, "origin_id": 118, "origin_slot": 0, "target_id": 151, "target_slot": 0, "type": "VAE"}, {"id": 368, "origin_id": -10, "origin_slot": 3, "target_id": 151, "target_slot": 4, "type": "BOOLEAN"}, {"id": 370, "origin_id": -10, "origin_slot": 2, "target_id": 149, "target_slot": 3, "type": "FLOAT"}, {"id": 371, "origin_id": -10, "origin_slot": 2, "target_id": 151, "target_slot": 3, "type": "FLOAT"}, {"id": 379, "origin_id": -10, "origin_slot": 4, "target_id": 151, "target_slot": 1, "type": "IMAGE"}, {"id": 383, "origin_id": 155, "origin_slot": 0, "target_id": 111, "target_slot": 0, "type": "VAE"}, {"id": 384, "origin_id": 155, "origin_slot": 0, "target_id": 139, "target_slot": 1, "type": "VAE"}, {"id": 385, "origin_id": -10, "origin_slot": 5, "target_id": 103, "target_slot": 0, "type": "COMBO"}, {"id": 386, "origin_id": -10, "origin_slot": 5, "target_id": 155, "target_slot": 0, "type": "COMBO"}, {"id": 387, "origin_id": -10, "origin_slot": 5, "target_id": 97, "target_slot": 1, "type": "COMBO"}, {"id": 388, "origin_id": -10, "origin_slot": 6, "target_id": 134, "target_slot": 1, "type": "COMBO"}, {"id": 389, "origin_id": -10, "origin_slot": 7, "target_id": 97, "target_slot": 0, "type": "COMBO"}, {"id": 390, "origin_id": -10, "origin_slot": 8, "target_id": 105, "target_slot": 1, "type": "COMBO"}, {"id": 391, "origin_id": -10, "origin_slot": 9, "target_id": 100, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Canny to video"}]}, "config": {}, "extra": {"workflowRendererVersion": "LG", "ds": {"scale": 0.7537190265006444, "offset": [-330.27244430536007, -3324.725077010053]}}, "version": 0.4} diff --git a/blueprints/Chromatic Aberration.json b/blueprints/Chromatic Aberration.json new file mode 100644 index 000000000..5513cc665 --- /dev/null +++ b/blueprints/Chromatic Aberration.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 19, "last_link_id": 0, "nodes": [{"id": 19, "type": "2c5ef154-2bde-496d-bc8b-9dcf42f2913f", "pos": [3710, -2070], "size": [260, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Chromatic Aberration", "properties": {"proxyWidgets": [["17", "choice"], ["18", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "2c5ef154-2bde-496d-bc8b-9dcf42f2913f", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 18, "lastLinkId": 23, "lastRerouteId": 0}, "revision": 0, "config": {}, 
"name": "Chromatic Aberration", "inputNode": {"id": -10, "bounding": [3270, -2050, 120, 60]}, "outputNode": {"id": -20, "bounding": [4260, -2050, 120, 60]}, "inputs": [{"id": "3b33ac46-93a6-4b1c-896a-ed6fbd24e59c", "name": "images.image0", "type": "IMAGE", "linkIds": [20], "localized_name": "images.image0", "label": "image", "pos": [3370, -2030]}], "outputs": [{"id": "abe7cd79-a87b-4bd0-8923-d79a57d81a6e", "name": "IMAGE0", "type": "IMAGE", "linkIds": [23], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4280, -2030]}], "widgets": [], "nodes": [{"id": 16, "type": "GLSLShader", "pos": [3810, -2320], "size": [390, 212], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 20}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 22}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 21}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [23]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Mode\nuniform float u_float0; // Amount (0 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int MODE_LINEAR = 0;\nconst int MODE_RADIAL = 1;\nconst int MODE_BARREL = 2;\nconst int MODE_SWIRL = 3;\nconst int MODE_DIAGONAL = 4;\n\nconst float AMOUNT_SCALE = 0.0005;\nconst float RADIAL_MULT = 4.0;\nconst float BARREL_MULT = 8.0;\nconst float INV_SQRT2 = 0.70710678118;\n\nvoid main() {\n vec2 uv = v_texCoord;\n vec4 original = texture(u_image0, uv);\n\n float amount = u_float0 * AMOUNT_SCALE;\n\n if (amount < 0.000001) {\n fragColor = original;\n return;\n }\n\n // Aspect-corrected coordinates for circular effects\n float aspect = u_resolution.x / u_resolution.y;\n vec2 centered = uv - 0.5;\n vec2 corrected = vec2(centered.x * aspect, centered.y);\n float r = length(corrected);\n vec2 dir = r > 0.0001 ? 
corrected / r : vec2(0.0);\n vec2 offset = vec2(0.0);\n\n if (u_int0 == MODE_LINEAR) {\n // Horizontal shift (no aspect correction needed)\n offset = vec2(amount, 0.0);\n }\n else if (u_int0 == MODE_RADIAL) {\n // Outward from center, stronger at edges\n offset = dir * r * amount * RADIAL_MULT;\n offset.x /= aspect; // Convert back to UV space\n }\n else if (u_int0 == MODE_BARREL) {\n // Lens distortion simulation (r² falloff)\n offset = dir * r * r * amount * BARREL_MULT;\n offset.x /= aspect; // Convert back to UV space\n }\n else if (u_int0 == MODE_SWIRL) {\n // Perpendicular to radial (rotational aberration)\n vec2 perp = vec2(-dir.y, dir.x);\n offset = perp * r * amount * RADIAL_MULT;\n offset.x /= aspect; // Convert back to UV space\n }\n else if (u_int0 == MODE_DIAGONAL) {\n // 45° offset (no aspect correction needed)\n offset = vec2(amount, amount) * INV_SQRT2;\n }\n \n float red = texture(u_image0, uv + offset).r;\n float green = original.g;\n float blue = texture(u_image0, uv - offset).b;\n \n fragColor = vec4(red, green, blue, original.a);\n}", "from_input"]}, {"id": 18, "type": "PrimitiveFloat", "pos": [3810, -2430], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "amount", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [22]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "step": 1}, "widgets_values": [30]}, {"id": 17, "type": "CustomCombo", "pos": [3520, -2320], "size": [270, 222], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "mode", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [21]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["Linear", 0, "Linear", "Radial", "Barrel", "Swirl", "Diagonal", ""]}], "groups": [], "links": [{"id": 22, "origin_id": 18, "origin_slot": 0, "target_id": 16, "target_slot": 2, "type": "FLOAT"}, {"id": 21, "origin_id": 17, "origin_slot": 1, "target_id": 16, "target_slot": 4, "type": "INT"}, {"id": 20, "origin_id": -10, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "IMAGE"}, {"id": 23, "origin_id": 16, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Color Adjustment.json b/blueprints/Color Adjustment.json new file mode 100644 index 000000000..47f3df783 --- /dev/null +++ b/blueprints/Color Adjustment.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 14, "last_link_id": 0, "nodes": [{"id": 14, "type": "36677b92-5dd8-47a5-9380-4da982c1894f", "pos": [3610, -2630], "size": [270, 150], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["4", "value"], ["5", "value"], ["7", "value"], ["6", "value"]]}, "widgets_values": [], "title": "Color Adjustment"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "36677b92-5dd8-47a5-9380-4da982c1894f", "version": 1, "state": {"lastGroupId": 0, 
"lastNodeId": 16, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Color Adjustment", "inputNode": {"id": -10, "bounding": [3110, -3560, 120, 60]}, "outputNode": {"id": -20, "bounding": [4070, -3560, 120, 60]}, "inputs": [{"id": "0431d493-5f28-4430-bd00-84733997fc08", "name": "images.image0", "type": "IMAGE", "linkIds": [29], "localized_name": "images.image0", "label": "image", "pos": [3210, -3540]}], "outputs": [{"id": "bee8ea06-a114-4612-8937-939f2c927bdb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4090, -3540]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [3590, -3940], "size": [420, 252], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 29}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 34}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 30}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 31}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 33}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [28]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // temperature (-100 to 100)\nuniform float u_float1; // tint (-100 to 100)\nuniform float u_float2; // vibrance (-100 to 100)\nuniform float u_float3; // saturation (-100 to 100)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst float INPUT_SCALE = 0.01;\nconst float TEMP_TINT_PRIMARY = 0.3;\nconst float TEMP_TINT_SECONDARY = 0.15;\nconst float VIBRANCE_BOOST = 2.0;\nconst float SATURATION_BOOST = 2.0;\nconst float SKIN_PROTECTION = 0.5;\nconst float EPSILON = 0.001;\nconst vec3 LUMA_WEIGHTS = vec3(0.299, 0.587, 0.114);\n\nvoid main() {\n vec4 tex = texture(u_image0, v_texCoord);\n vec3 color = tex.rgb;\n \n // Scale inputs: -100/100 \u2192 -1/1\n float temperature = u_float0 * INPUT_SCALE;\n float tint = u_float1 * INPUT_SCALE;\n float vibrance = u_float2 * INPUT_SCALE;\n float saturation = u_float3 * INPUT_SCALE;\n \n // Temperature (warm/cool): positive = warm, negative = cool\n color.r += temperature * TEMP_TINT_PRIMARY;\n color.b -= temperature * TEMP_TINT_PRIMARY;\n \n // Tint (green/magenta): positive 
= green, negative = magenta\n color.g += tint * TEMP_TINT_PRIMARY;\n color.r -= tint * TEMP_TINT_SECONDARY;\n color.b -= tint * TEMP_TINT_SECONDARY;\n \n // Single clamp after temperature/tint\n color = clamp(color, 0.0, 1.0);\n \n // Vibrance with skin protection\n if (vibrance != 0.0) {\n float maxC = max(color.r, max(color.g, color.b));\n float minC = min(color.r, min(color.g, color.b));\n float sat = maxC - minC;\n float gray = dot(color, LUMA_WEIGHTS);\n \n if (vibrance < 0.0) {\n // Desaturate: -100 \u2192 gray\n color = mix(vec3(gray), color, 1.0 + vibrance);\n } else {\n // Boost less saturated colors more\n float vibranceAmt = vibrance * (1.0 - sat);\n \n // Branchless skin tone protection\n float isWarmTone = step(color.b, color.g) * step(color.g, color.r);\n float warmth = (color.r - color.b) / max(maxC, EPSILON);\n float skinTone = isWarmTone * warmth * sat * (1.0 - sat);\n vibranceAmt *= (1.0 - skinTone * SKIN_PROTECTION);\n \n color = mix(vec3(gray), color, 1.0 + vibranceAmt * VIBRANCE_BOOST);\n }\n }\n \n // Saturation\n if (saturation != 0.0) {\n float gray = dot(color, LUMA_WEIGHTS);\n float satMix = saturation < 0.0\n ? 1.0 + saturation // -100 \u2192 gray\n : 1.0 + saturation * SATURATION_BOOST; // +100 \u2192 3x boost\n color = mix(vec3(gray), color, satMix);\n }\n \n fragColor = vec4(clamp(color, 0.0, 1.0), tex.a);\n}", "from_input"]}, {"id": 6, "type": "PrimitiveFloat", "pos": [3290, -3610], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "vibrance", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26, 31]}], "title": "Vibrance", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 7, "type": "PrimitiveFloat", "pos": [3290, -3720], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [33]}], "title": "Saturation", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [3290, -3830], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "tint", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [30]}], "title": "Tint", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 255, 0]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 0, 255]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [3290, -3940], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "temperature", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", 
"type": "FLOAT", "links": [34]}], "title": "Temperature", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": -100, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [68, 136, 255]}, {"offset": 0.5, "color": [255, 255, 255]}, {"offset": 1, "color": [255, 136, 0]}]}, "widgets_values": [0]}], "groups": [], "links": [{"id": 34, "origin_id": 4, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 30, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 31, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 33, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 29, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Depth to Image (Z-Image-Turbo).json b/blueprints/Depth to Image (Z-Image-Turbo).json new file mode 100644 index 000000000..baffc4fc9 --- /dev/null +++ b/blueprints/Depth to Image (Z-Image-Turbo).json @@ -0,0 +1 @@ +{"id": "e046dd74-e2a7-4f31-a75b-5e11a8c72d4e", "revision": 0, "last_node_id": 76, "last_link_id": 259, "nodes": [{"id": 13, "type": "d8492a46-9e6c-4917-b5ea-4273aabf5f51", "pos": [400, 3630], "size": [400, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "control image", "name": "image", "type": "IMAGE", "link": null}, {"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": null}, {"label": "lotus_model", "name": "unet_name_1", "type": "COMBO", "widget": {"name": "unet_name_1"}, "link": null}, {"label": "sd15_vae", "name": "vae_name_1", "type": "COMBO", "widget": {"name": "vae_name_1"}, "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": null}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "name"], ["-1", "unet_name_1"], ["-1", "vae_name_1"], ["7", "control_after_generate"], ["7", "seed"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", "z_image_turbo_bf16.safetensors", "qwen_3_4b.safetensors", "ae.safetensors", "Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "lotus-depth-d-v1-1.safetensors", "vae-ft-mse-840000-ema-pruned.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "d8492a46-9e6c-4917-b5ea-4273aabf5f51", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 76, "lastLinkId": 259, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Depth to Image (Z-Image-Turbo)", "inputNode": {"id": -10, "bounding": [27.60368520069494, 4936.043696127976, 120, 200]}, "outputNode": {"id": -20, "bounding": [1598.6038576146689, 4936.043696127976, 120, 60]}, "inputs": [{"id": "29ca271b-8f63-4e7b-a4b8-c9b4192ada0b", "name": "image", "type": "IMAGE", "linkIds": [25], "label": "control image", "pos": [127.60368520069494, 4956.043696127976]}, {"id": "b6549f90-39ee-4b79-9e00-af4d9df969fe", "name": "text", "type": 
"STRING", "linkIds": [16], "label": "prompt", "pos": [127.60368520069494, 4976.043696127976]}, {"id": "add4a703-1185-4848-9494-b27dd37ff434", "name": "unet_name", "type": "COMBO", "linkIds": [252], "pos": [127.60368520069494, 4996.043696127976]}, {"id": "03233f9e-df65-4e05-b5c5-34d83129e85e", "name": "clip_name", "type": "COMBO", "linkIds": [253], "pos": [127.60368520069494, 5016.043696127976]}, {"id": "0c643ffb-326d-40ca-8a89-ebc585cf5015", "name": "vae_name", "type": "COMBO", "linkIds": [254], "pos": [127.60368520069494, 5036.043696127976]}, {"id": "409cdebe-632b-410f-a66c-711c2a1527e1", "name": "name", "type": "COMBO", "linkIds": [255], "pos": [127.60368520069494, 5056.043696127976]}, {"id": "80e6915f-5d59-4d6b-a197-d8c565ad2922", "name": "unet_name_1", "type": "COMBO", "linkIds": [258], "label": "lotus_model", "pos": [127.60368520069494, 5076.043696127976]}, {"id": "4207ec84-4409-4816-8444-76062bf6310c", "name": "vae_name_1", "type": "COMBO", "linkIds": [259], "label": "sd15_vae", "pos": [127.60368520069494, 5096.043696127976]}], "outputs": [{"id": "47f9a22d-6619-4917-9447-a7d5d08dceb5", "name": "IMAGE", "type": "IMAGE", "linkIds": [18], "pos": [1618.6038576146689, 4956.043696127976]}], "widgets": [], "nodes": [{"id": 1, "type": "CLIPLoader", "pos": [228.60381716506714, 4700.188262345759], "size": [269.9479166666667, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 253}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [14]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_3_4b.safetensors", "lumina2", "default"]}, {"id": 2, "type": "UNETLoader", "pos": [228.60381716506714, 4550.188402733727], "size": [269.9479166666667, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 252}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [9]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "UNETLoader", "models": [{"name": "z_image_turbo_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["z_image_turbo_bf16.safetensors", "default"]}, {"id": 3, "type": "VAELoader", "pos": [228.60381716506714, 4880.188283008492], "size": [269.9479166666667, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "vae_name", 
"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 254}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [2, 11]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "VAELoader", "models": [{"name": "ae.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ae.safetensors"]}, {"id": 4, "type": "ModelPatchLoader", "pos": [228.60381716506714, 5010.1883654774], "size": [269.9479166666667, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "name", "name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": 255}], "outputs": [{"localized_name": "MODEL_PATCH", "name": "MODEL_PATCH", "type": "MODEL_PATCH", "links": [10]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ModelPatchLoader", "models": [{"name": "Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "url": "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "directory": "model_patches"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["Z-Image-Turbo-Fun-Controlnet-Union.safetensors"]}, {"id": 6, "type": "ModelSamplingAuraFlow", "pos": [998.6041081931173, 4490.1880693746825], "size": [289.97395833333337, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 3}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [4]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 7, "type": "KSampler", "pos": [998.6041081931173, 4600.188363442829], "size": [300, 262], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 4}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 5}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 6}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 7}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [1]}], 
"properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize", 9, 1, "res_multistep", "simple", 1]}, {"id": 8, "type": "ConditioningZeroOut", "pos": [748.2706508086186, 5044.854997097082], "size": [204.134765625, 26], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 8}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [6]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "ConditioningZeroOut", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 10, "type": "EmptySD3LatentImage", "pos": [1028.2702326451792, 5334.855683329977], "size": [259.9479166666667, 106], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 12}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 13}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [7]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "EmptySD3LatentImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1024, 1024, 1]}, {"id": 5, "type": "VAEDecode", "pos": [1338.604012131086, 4500.188453282262], "size": [200, 46], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 1}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 2}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [18]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 9, "type": "QwenImageDiffsynthControlnet", "pos": [608.2704996459613, 5204.85528564724], "size": [289.97395833333337, 138], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 9}, {"localized_name": "model_patch", "name": "model_patch", "type": "MODEL_PATCH", "link": 10}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 11}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 248}, {"localized_name": "mask", "name": "mask", "shape": 7, "type": "MASK", "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [3]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "QwenImageDiffsynthControlnet", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, 
"secondTabWidth": 65}, "widgets_values": [1]}, {"id": 11, "type": "GetImageSize", "pos": [530, 5440], "size": [140, 66], "flags": {"collapsed": false}, "order": 10, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 247}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [12]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [13]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": null}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 12, "type": "CLIPTextEncode", "pos": [548.2706278500244, 4544.854827124228], "size": [400, 420], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 14}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 16}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [5, 8]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 14, "type": "ImageScaleToTotalPixels", "pos": [90, 5180], "size": [270, 106], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 25}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "megapixels", "name": "megapixels", "type": "FLOAT", "widget": {"name": "megapixels"}, "link": null}, {"localized_name": "resolution_steps", "name": "resolution_steps", "type": "INT", "widget": {"name": "resolution_steps"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [248, 250]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "ImageScaleToTotalPixels"}, "widgets_values": ["lanczos", 1, 1]}, {"id": 15, "type": "PreviewImage", "pos": [90, 5530], "size": [380, 260], "flags": {}, "order": 13, "mode": 4, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 251}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "PreviewImage"}, "widgets_values": []}, {"id": 76, "type": "458bdf3c-4b58-421c-af50-c9c663a4d74c", "pos": [90, 5340], "size": [400, 150], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 250}, {"label": "depth_intensity", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 258}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 259}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [247, 251]}], "properties": {"proxyWidgets": [["-1", "sigma"], ["-1", "unet_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [999.0000000000002, "lotus-depth-d-v1-1.safetensors", 
"vae-ft-mse-840000-ema-pruned.safetensors"]}], "groups": [{"id": 1, "title": "Prompt", "bounding": [530, 4470, 440, 630], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Models", "bounding": [210, 4470, 300, 640], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Apple ControlNet", "bounding": [530, 5120, 440, 260], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 1, "origin_id": 7, "origin_slot": 0, "target_id": 5, "target_slot": 0, "type": "LATENT"}, {"id": 2, "origin_id": 3, "origin_slot": 0, "target_id": 5, "target_slot": 1, "type": "VAE"}, {"id": 3, "origin_id": 9, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "MODEL"}, {"id": 4, "origin_id": 6, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "MODEL"}, {"id": 5, "origin_id": 12, "origin_slot": 0, "target_id": 7, "target_slot": 1, "type": "CONDITIONING"}, {"id": 6, "origin_id": 8, "origin_slot": 0, "target_id": 7, "target_slot": 2, "type": "CONDITIONING"}, {"id": 7, "origin_id": 10, "origin_slot": 0, "target_id": 7, "target_slot": 3, "type": "LATENT"}, {"id": 8, "origin_id": 12, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "CONDITIONING"}, {"id": 9, "origin_id": 2, "origin_slot": 0, "target_id": 9, "target_slot": 0, "type": "MODEL"}, {"id": 10, "origin_id": 4, "origin_slot": 0, "target_id": 9, "target_slot": 1, "type": "MODEL_PATCH"}, {"id": 11, "origin_id": 3, "origin_slot": 0, "target_id": 9, "target_slot": 2, "type": "VAE"}, {"id": 12, "origin_id": 11, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "INT"}, {"id": 13, "origin_id": 11, "origin_slot": 1, "target_id": 10, "target_slot": 1, "type": "INT"}, {"id": 14, "origin_id": 1, "origin_slot": 0, "target_id": 12, "target_slot": 0, "type": "CLIP"}, {"id": 16, "origin_id": -10, "origin_slot": 1, "target_id": 12, "target_slot": 1, "type": "STRING"}, {"id": 18, "origin_id": 5, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 25, "origin_id": -10, "origin_slot": 0, "target_id": 14, "target_slot": 0, "type": "IMAGE"}, {"id": 247, "origin_id": 76, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 248, "origin_id": 14, "origin_slot": 0, "target_id": 9, "target_slot": 3, "type": "IMAGE"}, {"id": 250, "origin_id": 14, "origin_slot": 0, "target_id": 76, "target_slot": 0, "type": "IMAGE"}, {"id": 251, "origin_id": 76, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 252, "origin_id": -10, "origin_slot": 2, "target_id": 2, "target_slot": 0, "type": "COMBO"}, {"id": 253, "origin_id": -10, "origin_slot": 3, "target_id": 1, "target_slot": 0, "type": "COMBO"}, {"id": 254, "origin_id": -10, "origin_slot": 4, "target_id": 3, "target_slot": 0, "type": "COMBO"}, {"id": 255, "origin_id": -10, "origin_slot": 5, "target_id": 4, "target_slot": 0, "type": "COMBO"}, {"id": 258, "origin_id": -10, "origin_slot": 6, "target_id": 76, "target_slot": 2, "type": "COMBO"}, {"id": 259, "origin_id": -10, "origin_slot": 7, "target_id": 76, "target_slot": 3, "type": "COMBO"}], "extra": {"ds": {"scale": 1.3889423076923078, "offset": [22.056074766355096, -3503.3333333333335]}, "frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "category": "Image generation and editing/Depth to image"}, {"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 76, 
"lastLinkId": 259, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image to Depth Map (Lotus)", "inputNode": {"id": -10, "bounding": [-60, -172.61268043518066, 126.625, 120]}, "outputNode": {"id": -20, "bounding": [1650, -172.61268043518066, 120, 60]}, "inputs": [{"id": "3bdd30c3-4ec9-485a-814b-e7d39fb6b5cc", "name": "pixels", "type": "IMAGE", "linkIds": [37], "localized_name": "pixels", "pos": [46.625, -152.61268043518066]}, {"id": "f9a1017c-f4b9-43b4-94c2-41c088b3a492", "name": "sigma", "type": "FLOAT", "linkIds": [243], "label": "depth_intensity", "pos": [46.625, -132.61268043518066]}, {"id": "d721b249-fd2a-441b-9a78-2805f04e2644", "name": "unet_name", "type": "COMBO", "linkIds": [256], "pos": [46.625, -112.61268043518066]}, {"id": "0430e2ea-f8b5-4191-9b72-b7d62176f97c", "name": "vae_name", "type": "COMBO", "linkIds": [257], "pos": [46.625, -92.61268043518066]}], "outputs": [{"id": "2ec278bd-0b66-4b30-9c5b-994d5f638214", "name": "IMAGE", "type": "IMAGE", "linkIds": [242], "localized_name": "IMAGE", "pos": [1670, -152.61268043518066]}], "widgets": [], "nodes": [{"id": 8, "type": "VAEDecode", "pos": [1380.0000135211146, -240.0000135211144], "size": [210, 60], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 232}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 240}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [35]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEDecode", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 10, "type": "UNETLoader", "pos": [135.34178335388546, -290.1947851765315], "size": [305.9244791666667, 97.7734375], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 256}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [31, 241]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "UNETLoader", "models": [{"name": "lotus-depth-d-v1-1.safetensors", "url": "https://huggingface.co/Comfy-Org/lotus/resolve/main/lotus-depth-d-v1-1.safetensors", "directory": "diffusion_models"}], "widget_ue_connectable": {}}, "widgets_values": ["lotus-depth-d-v1-1.safetensors", "default"]}, {"id": 14, "type": "VAELoader", "pos": [134.53144605616137, -165.18194011768782], "size": [305.9244791666667, 68.88020833333334], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 257}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [38, 240]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAELoader", "models": [{"name": "vae-ft-mse-840000-ema-pruned.safetensors", "url": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors", "directory": "vae"}], "widget_ue_connectable": {}}, "widgets_values": ["vae-ft-mse-840000-ema-pruned.safetensors"]}, {"id": 16, "type": "SamplerCustomAdvanced", "pos": [990.6585475753939, -319.91444852782104], "size": [355.1953125, 325.98958333333337], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "noise", "name": 
"noise", "type": "NOISE", "link": 237}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 27}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 33}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 194}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 201}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "slot_index": 0, "links": [232]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "slot_index": 1, "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SamplerCustomAdvanced", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 18, "type": "DisableNoise", "pos": [730.4769792883567, -320.00005408445816], "size": [210, 40], "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "slot_index": 0, "links": [237]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "DisableNoise", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 19, "type": "BasicGuider", "pos": [730.2630921572128, -251.22541185314978], "size": [210, 60], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 241}, {"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 238}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "slot_index": 0, "links": [27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicGuider", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 20, "type": "BasicScheduler", "pos": [488.64457755981744, -147.67201223931278], "size": [210, 122.21354166666667], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 31}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [66]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicScheduler", "widget_ue_connectable": {}}, "widgets_values": ["normal", 1, 1]}, {"id": 21, "type": "KSamplerSelect", "pos": [730.2630921572128, -161.22540847287118], "size": [210, 68.88020833333334], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "slot_index": 0, "links": [33]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "KSamplerSelect", "widget_ue_connectable": {}}, "widgets_values": ["euler"]}, {"id": 22, "type": "ImageInvert", "pos": [1373.3333333333335, -318.33333333333337], "size": [210, 40], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 35}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [242]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node 
name for S&R": "ImageInvert", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 23, "type": "VAEEncode", "pos": [730.2630921572128, 38.774608428522015], "size": [210, 60], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 37}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 38}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [201]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEEncode", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 28, "type": "SetFirstSigma", "pos": [730.2630921572128, -61.225357768691524], "size": [210, 66.66666666666667], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 66}, {"localized_name": "sigma", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": 243}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [194]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SetFirstSigma", "widget_ue_connectable": {}}, "widgets_values": [999.0000000000002]}, {"id": 68, "type": "LotusConditioning", "pos": [489.99998478874613, -229.99996619721344], "size": [210, 40], "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "slot_index": 0, "links": [238]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "LotusConditioning", "widget_ue_connectable": {}}, "widgets_values": []}], "groups": [{"id": 2, "title": "Models", "bounding": [123.33333333333334, -351.6666666666667, 323.4014831310574, 263.55972005884377], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 232, "origin_id": 16, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 240, "origin_id": 14, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 237, "origin_id": 18, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "NOISE"}, {"id": 27, "origin_id": 19, "origin_slot": 0, "target_id": 16, "target_slot": 1, "type": "GUIDER"}, {"id": 33, "origin_id": 21, "origin_slot": 0, "target_id": 16, "target_slot": 2, "type": "SAMPLER"}, {"id": 194, "origin_id": 28, "origin_slot": 0, "target_id": 16, "target_slot": 3, "type": "SIGMAS"}, {"id": 201, "origin_id": 23, "origin_slot": 0, "target_id": 16, "target_slot": 4, "type": "LATENT"}, {"id": 241, "origin_id": 10, "origin_slot": 0, "target_id": 19, "target_slot": 0, "type": "MODEL"}, {"id": 238, "origin_id": 68, "origin_slot": 0, "target_id": 19, "target_slot": 1, "type": "CONDITIONING"}, {"id": 31, "origin_id": 10, "origin_slot": 0, "target_id": 20, "target_slot": 0, "type": "MODEL"}, {"id": 35, "origin_id": 8, "origin_slot": 0, "target_id": 22, "target_slot": 0, "type": "IMAGE"}, {"id": 38, "origin_id": 14, "origin_slot": 0, "target_id": 23, "target_slot": 1, "type": "VAE"}, {"id": 66, "origin_id": 20, "origin_slot": 0, "target_id": 28, "target_slot": 0, "type": "SIGMAS"}, {"id": 37, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 242, "origin_id": 22, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 243, "origin_id": -10, "origin_slot": 1, "target_id": 28, "target_slot": 1, "type": "FLOAT"}, {"id": 256, "origin_id": -10, "origin_slot": 2, 
"target_id": 10, "target_slot": 0, "type": "COMBO"}, {"id": 257, "origin_id": -10, "origin_slot": 3, "target_id": 14, "target_slot": 0, "type": "COMBO"}], "extra": {"ds": {"scale": 1.2354281696404266, "offset": [-114.15605447786857, -754.3368938705543]}, "workflowRendererVersion": "LG"}}]}, "config": {}, "extra": {"ds": {"scale": 0.7886233956111374, "offset": [741.6589462093539, -3278.0806447095165]}, "frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "version": 0.4} diff --git a/blueprints/Depth to Video (ltx 2.0).json b/blueprints/Depth to Video (ltx 2.0).json new file mode 100644 index 000000000..9656b6253 --- /dev/null +++ b/blueprints/Depth to Video (ltx 2.0).json @@ -0,0 +1 @@ +{"id": "ec176c82-4db5-4ab9-b5a0-8aa8e5684a81", "revision": 0, "last_node_id": 191, "last_link_id": 433, "nodes": [{"id": 143, "type": "68857357-cbc2-4c3a-a786-c3a58d43f9b1", "pos": [289.99998661973035, 3960.0002084505168], "size": [400, 500], "flags": {"collapsed": false}, "order": 0, "mode": 0, "inputs": [{"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"label": "image_strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}, {"label": "disable_first_frame", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": null}, {"label": "depth reference video", "name": "video", "type": "VIDEO", "link": null}, {"label": "first frame", "name": "image_2", "type": "IMAGE", "link": null}, {"label": "width", "name": "resize_type.width", "type": "INT", "widget": {"name": "resize_type.width"}, "link": null}, {"label": "height", "name": "resize_type.height", "type": "INT", "widget": {"name": "resize_type.height"}, "link": null}, {"name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}, {"name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": null}, {"name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"name": "text_encoder", "type": "COMBO", "widget": {"name": "text_encoder"}, "link": null}, {"label": "distill_lora", "name": "lora_name_1", "type": "COMBO", "widget": {"name": "lora_name_1"}, "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}, {"label": "lotus_depth_model", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"label": "sd15_vae", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "bypass"], ["-1", "strength"], ["-1", "resize_type.width"], ["-1", "resize_type.height"], ["-1", "length"], ["126", "noise_seed"], ["143", "control_after_generate"], ["-1", "ckpt_name"], ["-1", "lora_name"], ["-1", "text_encoder"], ["-1", "lora_name_1"], ["-1", "model_name"], ["-1", "unet_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.7.0", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", false, 1, 1280, 720, 121, null, null, "ltx-2-19b-dev-fp8.safetensors", "ltx-2-19b-ic-lora-depth-control.safetensors", "gemma_3_12B_it_fp4_mixed.safetensors", "ltx-2-19b-distilled-lora-384.safetensors", "ltx-2-spatial-upscaler-x2-1.0.safetensors", 
"lotus-depth-d-v1-1.safetensors", "vae-ft-mse-840000-ema-pruned.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "68857357-cbc2-4c3a-a786-c3a58d43f9b1", "version": 1, "state": {"lastGroupId": 16, "lastNodeId": 191, "lastLinkId": 433, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Depth to Video (LTX 2.0)", "inputNode": {"id": -10, "bounding": [-2730, 4020, 165.30859375, 340]}, "outputNode": {"id": -20, "bounding": [1750, 4090, 120, 60]}, "inputs": [{"id": "0f1d2f96-933a-4a7b-8f1a-7b49fc4ade09", "name": "text", "type": "STRING", "linkIds": [345], "label": "prompt", "pos": [-2584.69140625, 4040]}, {"id": "59430efe-1090-4e36-8afe-b21ce7f4268b", "name": "strength", "type": "FLOAT", "linkIds": [370, 371], "label": "image_strength", "pos": [-2584.69140625, 4060]}, {"id": "6145a9b9-68ed-4956-89f7-7a5ebdd5c99e", "name": "bypass", "type": "BOOLEAN", "linkIds": [363, 368], "label": "disable_first_frame", "pos": [-2584.69140625, 4080]}, {"id": "de434962-832a-485c-a016-869b3f2176ca", "name": "video", "type": "VIDEO", "linkIds": [419], "label": "depth reference video", "pos": [-2584.69140625, 4100]}, {"id": "a1189d3d-bbff-4933-875d-cffa58dd4cb0", "name": "image_2", "type": "IMAGE", "linkIds": [410], "label": "first frame", "pos": [-2584.69140625, 4120]}, {"id": "577dae4c-447b-4c84-9973-56381fdbc6a9", "name": "resize_type.width", "type": "INT", "linkIds": [420], "label": "width", "pos": [-2584.69140625, 4140]}, {"id": "fb30c570-128c-46b8-a140-054aff294edc", "name": "resize_type.height", "type": "INT", "linkIds": [421], "label": "height", "pos": [-2584.69140625, 4160]}, {"id": "33d5f598-00ae-4e2d-8eb2-2da23ae5ba46", "name": "length", "type": "INT", "linkIds": [422], "pos": [-2584.69140625, 4180]}, {"id": "68cc58b0-2013-4c3a-81ff-3d1e86232d76", "name": "ckpt_name", "type": "COMBO", "linkIds": [425, 433], "pos": [-2584.69140625, 4200]}, {"id": "0c65a06b-e12a-4298-8d81-69e57a123188", "name": "lora_name", "type": "COMBO", "linkIds": [426], "pos": [-2584.69140625, 4220]}, {"id": "eba96545-b8c6-4fba-b086-ddeeb4a9130d", "name": "text_encoder", "type": "COMBO", "linkIds": [427], "pos": [-2584.69140625, 4240]}, {"id": "848f9d82-3fde-4b95-b226-4b0db7082112", "name": "lora_name_1", "type": "COMBO", "linkIds": [429], "label": "distill_lora", "pos": [-2584.69140625, 4260]}, {"id": "32ace7dd-4da8-416b-b1e3-00652b3e6838", "name": "model_name", "type": "COMBO", "linkIds": [430], "pos": [-2584.69140625, 4280]}, {"id": "d6ad1978-71b6-425b-be13-c8f1e1d798d9", "name": "unet_name", "type": "COMBO", "linkIds": [431], "label": "lotus_depth_model", "pos": [-2584.69140625, 4300]}, {"id": "b0545a5d-65e8-4baa-a7be-d5f3d2b8b6e3", "name": "vae_name", "type": "COMBO", "linkIds": [432], "label": "sd15_vae", "pos": [-2584.69140625, 4320]}], "outputs": [{"id": "4e837941-de2d-4df8-8f94-686e24036897", "name": "VIDEO", "type": "VIDEO", "linkIds": [304], "localized_name": "VIDEO", "pos": [1770, 4110]}], "widgets": [], "nodes": [{"id": 93, "type": "CFGGuider", "pos": [-697.9999467324425, 3670.0001318308678], "size": [270, 106.66666666666667], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 326}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 309}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 311}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": 
"GUIDER", "type": "GUIDER", "links": [261]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 94, "type": "KSamplerSelect", "pos": [-697.9999467324425, 3840.0000630985346], "size": [270, 68.88020833333334], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [262]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["euler"]}, {"id": 99, "type": "ManualSigmas", "pos": [409.9999946478922, 3850.0001667604133], "size": [270, 70], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "STRING", "widget": {"name": "sigmas"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [278]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "ManualSigmas", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["0.909375, 0.725, 0.421875, 0.0"]}, {"id": 101, "type": "LTXVConcatAVLatent", "pos": [409.9999946478922, 4100.000194929402], "size": [270, 110], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 365}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 266}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [279]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 108, "type": "CFGGuider", "pos": [409.9999946478922, 3700.00007661965], "size": [270, 106.66666666666667], "flags": {}, "order": 19, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 280}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 281}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 282}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [276]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.71", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 111, "type": "LTXVEmptyLatentAudio", "pos": [-1100.000003380279, 4810.000230985708], "size": [270, 120], "flags": {}, "order": 21, "mode": 0, "inputs": [{"localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 285}, {"localized_name": "frames_number", "name": "frames_number", "type": "INT", 
"widget": {"name": "frames_number"}, "link": 329}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "INT", "widget": {"name": "frame_rate"}, "link": 354}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "Latent", "name": "Latent", "type": "LATENT", "links": [300]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVEmptyLatentAudio", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [97, 25, 1]}, {"id": 123, "type": "SamplerCustomAdvanced", "pos": [-387.99998321128277, 3520.0000416901034], "size": [213.125, 120], "flags": {}, "order": 30, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 260}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 261}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 262}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 263}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 323}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": [272]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 114, "type": "LTXVConditioning", "pos": [-1134.000099492868, 4140.000243380063], "size": [270, 86.66666666666667], "flags": {}, "order": 24, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 292}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 293}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "FLOAT", "widget": {"name": "frame_rate"}, "link": 355}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [313]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [314]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVConditioning", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 119, "type": "CLIPTextEncode", "pos": [-1164.0000442816504, 3880.0001115491955], "size": [400, 200], "flags": {}, "order": 28, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 294}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [293]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["blurry, low quality, still frame, frames, watermark, overlay, titles, has blurbox, has subtitles"], "color": "#323", "bgcolor": "#535"}, {"id": 116, "type": "LTXVConcatAVLatent", 
"pos": [-519.9999874648, 4700.000189295605], "size": [187.5, 60], "flags": {}, "order": 26, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 324}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 300}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [322, 323]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 122, "type": "LTXVSeparateAVLatent", "pos": [-393.9999813239605, 3800.0000146478747], "size": [240, 60], "flags": {}, "order": 29, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 272}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "links": [270]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [266]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 124, "type": "CLIPTextEncode", "pos": [-1174.9999569014471, 3514.0002724504593], "size": [410, 320], "flags": {}, "order": 31, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 295}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 345}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [292]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 98, "type": "KSamplerSelect", "pos": [409.9999946478922, 3980.00004957742], "size": [270, 68.88020833333334], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [277]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["gradient_estimation"]}, {"id": 95, "type": "LTXVScheduler", "pos": [-699.9999766197394, 3980.00004957742], "size": [270, 170], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 322}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "max_shift", "name": "max_shift", "type": "FLOAT", "widget": {"name": "max_shift"}, "link": null}, {"localized_name": "base_shift", "name": "base_shift", "type": "FLOAT", "widget": {"name": "base_shift"}, "link": null}, {"localized_name": "stretch", "name": "stretch", "type": "BOOLEAN", "widget": {"name": "stretch"}, "link": null}, {"localized_name": 
"terminal", "name": "terminal", "type": "FLOAT", "widget": {"name": "terminal"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [263]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVScheduler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [20, 2.05, 0.95, true, 0.1]}, {"id": 126, "type": "RandomNoise", "pos": [-697.9999467324425, 3520.0000416901034], "size": [270, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [260]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "fixed"]}, {"id": 107, "type": "SamplerCustomAdvanced", "pos": [709.9999918309934, 3570.000193802643], "size": [212.3828125, 120], "flags": {}, "order": 18, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 347}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 276}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 277}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 278}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 279}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": []}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": [336]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 143, "type": "RandomNoise", "pos": [409.9999946478922, 3570.000193802643], "size": [270, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [347]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize"]}, {"id": 139, "type": "LTXVAudioVAEDecode", "pos": [1129.9999512676497, 3840.0000630985346], "size": [240, 60], "flags": {}, "order": 35, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 338}, {"label": "Audio VAE", "localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 340}], "outputs": [{"localized_name": "Audio", "name": "Audio", "type": "AUDIO", "links": [339]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVAudioVAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 134, "type": 
"LoraLoaderModelOnly", "pos": [-1650.0000287323687, 3760.0003323940673], "size": [420, 95.546875], "flags": {}, "order": 33, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 325}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 426}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [326, 327]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-ic-lora-depth-control.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Depth-Control/resolve/main/ltx-2-19b-ic-lora-depth-control.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-ic-lora-depth-control.safetensors", 1], "color": "#322", "bgcolor": "#533"}, {"id": 138, "type": "LTXVSeparateAVLatent", "pos": [730.0000160563236, 3730.0000214084316], "size": [193.2916015625, 60], "flags": {}, "order": 34, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 336}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "links": [337, 351]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [338]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 144, "type": "VAEDecodeTiled", "pos": [1119.9999391549845, 3640.000187042085], "size": [270, 150], "flags": {}, "order": 36, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 351}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 353}, {"localized_name": "tile_size", "name": "tile_size", "type": "INT", "widget": {"name": "tile_size"}, "link": null}, {"localized_name": "overlap", "name": "overlap", "type": "INT", "widget": {"name": "overlap"}, "link": null}, {"localized_name": "temporal_size", "name": "temporal_size", "type": "INT", "widget": {"name": "temporal_size"}, "link": null}, {"localized_name": "temporal_overlap", "name": "temporal_overlap", "type": "INT", "widget": {"name": "temporal_overlap"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [352]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "VAEDecodeTiled", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [512, 64, 4096, 8]}, {"id": 113, "type": "VAEDecode", "pos": [1129.9999512676497, 3530.000145351982], "size": [240, 60], "flags": {}, "order": 23, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 337}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 291}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "VAEDecode", 
"enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 145, "type": "PrimitiveInt", "pos": [-1630.0000045070383, 4620.0000923942835], "size": [270, 82], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "INT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "INT", "name": "INT", "type": "INT", "links": [354]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveInt", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24, "fixed"]}, {"id": 148, "type": "PrimitiveFloat", "pos": [-1630.0000045070383, 4749.99997521129], "size": [270, 66.66666666666667], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [355, 356]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveFloat", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24]}, {"id": 115, "type": "EmptyLTXVLatentVideo", "pos": [-1100.000003380279, 4609.999988732406], "size": [270, 146.66666666666669], "flags": {}, "order": 25, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 296}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 297}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 330}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [360]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "EmptyLTXVLatentVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [768, 512, 97, 1]}, {"id": 149, "type": "LTXVImgToVideoInplace", "pos": [-1089.9999912676137, 4400.000009014077], "size": [270, 151.9921875], "flags": {}, "order": 37, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 359}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 417}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 360}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 370}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 363}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [357]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 118, "type": "Reroute", "pos": [-229.99999095071237, 4210.000236619506], "size": [75, 26], "flags": {}, 
"order": 27, "mode": 0, "inputs": [{"name": "", "type": "*", "link": 303}], "outputs": [{"name": "", "type": "VAE", "links": [289, 291, 367]}], "properties": {"showOutputText": false, "horizontal": false}}, {"id": 151, "type": "LTXVImgToVideoInplace", "pos": [-19.999999788732577, 4070.0002501406198], "size": [270, 181.9921875], "flags": {}, "order": 38, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 367}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 410}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 366}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 371}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 368}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [365]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 104, "type": "LTXVCropGuides", "pos": [-9.999999119719098, 3840.0000630985346], "size": [240, 80], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 310}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 312}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 270}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [281]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [282]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "slot_index": 2, "links": [287]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVCropGuides", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 112, "type": "LTXVLatentUpsampler", "pos": [-9.999999119719098, 3960.0002084505168], "size": [260, 80], "flags": {}, "order": 22, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 287}, {"localized_name": "upscale_model", "name": "upscale_model", "type": "LATENT_UPSCALE_MODEL", "link": 288}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 289}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [366]}], "title": "spatial", "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVLatentUpsampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 132, "type": "LTXVAddGuide", "pos": [-599.9999928169079, 4420.000216337834], "size": [270, 209.16666666666669], "flags": {}, "order": 32, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 313}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 314}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 328}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 357}, {"localized_name": "image", "name": "image", "type": "IMAGE", 
"link": 418}, {"localized_name": "frame_idx", "name": "frame_idx", "type": "INT", "widget": {"name": "frame_idx"}, "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [309, 310]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [311, 312]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [324]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "LTXVAddGuide", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, 1]}, {"id": 96, "type": "LTXVAudioVAELoader", "pos": [-1650.0000287323687, 3910.000056337978], "size": [420, 68.88020833333334], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 377}], "outputs": [{"localized_name": "Audio VAE", "name": "Audio VAE", "type": "VAE", "links": [285, 340]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVAudioVAELoader", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}, {"id": 103, "type": "CheckpointLoaderSimple", "pos": [-1650.0000287323687, 3590.0000349295465], "size": [420, 108.88020833333334], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 425}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [325]}, {"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": []}, {"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [303, 328, 353, 359]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CheckpointLoaderSimple", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}, {"id": 105, "type": "LoraLoaderModelOnly", "pos": [-69.99999741197416, 3570.000193802643], "size": [390, 95.546875], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 327}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 429}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [280]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-distilled-lora-384.safetensors", "url": 
"https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-distilled-lora-384.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-distilled-lora-384.safetensors", 1]}, {"id": 100, "type": "LatentUpscaleModelLoader", "pos": [-69.99999741197416, 3700.00007661965], "size": [390, 68.88020833333334], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 430}], "outputs": [{"localized_name": "LATENT_UPSCALE_MODEL", "name": "LATENT_UPSCALE_MODEL", "type": "LATENT_UPSCALE_MODEL", "links": [288]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LatentUpscaleModelLoader", "models": [{"name": "ltx-2-spatial-upscaler-x2-1.0.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-spatial-upscaler-x2-1.0.safetensors", "directory": "latent_upscale_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-spatial-upscaler-x2-1.0.safetensors"]}, {"id": 110, "type": "GetImageSize", "pos": [-1630.0000045070383, 4450.000161126616], "size": [260, 80], "flags": {}, "order": 20, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 416}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [296]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [297]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": [329, 330]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 106, "type": "CreateVideo", "pos": [1419.9999363380857, 3760.0003323940673], "size": [270, 86.66666666666667], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 352}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 339}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 356}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [304]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "CreateVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 187, "type": "ImageFromBatch", "pos": [-2310.000095774562, 3689.999972957771], "size": [260, 93.33333333333334], "flags": {}, "order": 39, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 412}, {"localized_name": "batch_index", "name": "batch_index", "type": "INT", "widget": {"name": "batch_index"}, "link": null}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 422}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [415]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "ImageFromBatch", "enableTabs": false, 
"tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, 121]}, {"id": 191, "type": "ResizeImageMaskNode", "pos": [-2320.0000163380137, 3850.0001667604133], "size": [284.375, 154], "flags": {}, "order": 43, "mode": 0, "inputs": [{"localized_name": "input", "name": "input", "type": "IMAGE,MASK", "link": 415}, {"localized_name": "resize_type", "name": "resize_type", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "resize_type"}, "link": null}, {"localized_name": "width", "name": "resize_type.width", "type": "INT", "widget": {"name": "resize_type.width"}, "link": 420}, {"localized_name": "height", "name": "resize_type.height", "type": "INT", "widget": {"name": "resize_type.height"}, "link": 421}, {"localized_name": "crop", "name": "resize_type.crop", "type": "COMBO", "widget": {"name": "resize_type.crop"}, "link": null}, {"localized_name": "scale_method", "name": "scale_method", "type": "COMBO", "widget": {"name": "scale_method"}, "link": null}], "outputs": [{"localized_name": "resized", "name": "resized", "type": "IMAGE", "links": [413]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "ResizeImageMaskNode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["scale dimensions", 1280, 720, "center", "lanczos"]}, {"id": 188, "type": "GetVideoComponents", "pos": [-2320.0000163380137, 3520.0000416901034], "size": [280, 80], "flags": {"collapsed": false}, "order": 40, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 419}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [412]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": []}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "GetVideoComponents", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 189, "type": "ImageScaleBy", "pos": [-1990.0000743661303, 3670.0001318308678], "size": [280, 125.546875], "flags": {}, "order": 41, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 413}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "scale_by", "name": "scale_by", "type": "FLOAT", "widget": {"name": "scale_by"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [414]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "ImageScaleBy", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["lanczos", 0.5]}, {"id": 154, "type": "MarkdownNote", "pos": [-1659.9999492958204, 4870.000120563272], "size": [350, 170], "flags": {"collapsed": false}, "order": 7, "mode": 0, "inputs": [], "outputs": [], "title": "Frame Rate Note", "properties": {}, "widgets_values": ["Please make sure the frame rate value is the same in both boxes"], "color": "#222", "bgcolor": "#000"}, {"id": 190, "type": "38b60539-50a7-42f9-a5fe-bdeca26272e2", "pos": 
[-1999.9999949295823, 3910.000056337978], "size": [310, 106], "flags": {}, "order": 42, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 414}, {"label": "depth_intensity", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 431}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 432}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [416, 417, 418]}], "properties": {"proxyWidgets": [["-1", "sigma"], ["-1", "unet_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [999.0000000000002, "lotus-depth-d-v1-1.safetensors", "vae-ft-mse-840000-ema-pruned.safetensors"], "color": "#322", "bgcolor": "#533"}, {"id": 97, "type": "LTXAVTextEncoderLoader", "pos": [-1650.0000287323687, 4040.0003053518376], "size": [420, 124.44010416666667], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "text_encoder", "name": "text_encoder", "type": "COMBO", "widget": {"name": "text_encoder"}, "link": 427}, {"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 433}, {"localized_name": "device", "name": "device", "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [294, 295]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXAVTextEncoderLoader", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}, {"name": "gemma_3_12B_it_fp4_mixed.safetensors", "url": "https://huggingface.co/Comfy-Org/ltx-2/resolve/main/split_files/text_encoders/gemma_3_12B_it_fp4_mixed.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["gemma_3_12B_it_fp4_mixed.safetensors", "ltx-2-19b-dev-fp8.safetensors", "default"]}], "groups": [{"id": 1, "title": "Model", "bounding": [-1660, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Basic Sampling", "bounding": [-700, 3440, 570, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Prompt", "bounding": [-1180, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 5, "title": "Latent", "bounding": [-1180, 4290, 1050, 680], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 9, "title": "Upscale Sampling(2x)", "bounding": [-100, 3440, 1090, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 6, "title": "Sampler", "bounding": [350, 3480, 620, 750], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 7, "title": "Model", "bounding": [-90, 3480, 430, 310], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 11, "title": "Frame rate", "bounding": [-1640, 4550, 290, 271.6], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 16, "title": "Video Preprocess", "bounding": [-2330, 3450, 650, 567.6], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 15, "title": "video length", "bounding": [-2320, 3620, 290, 180], 
"color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 326, "origin_id": 134, "origin_slot": 0, "target_id": 93, "target_slot": 0, "type": "MODEL"}, {"id": 309, "origin_id": 132, "origin_slot": 0, "target_id": 93, "target_slot": 1, "type": "CONDITIONING"}, {"id": 311, "origin_id": 132, "origin_slot": 1, "target_id": 93, "target_slot": 2, "type": "CONDITIONING"}, {"id": 266, "origin_id": 122, "origin_slot": 1, "target_id": 101, "target_slot": 1, "type": "LATENT"}, {"id": 280, "origin_id": 105, "origin_slot": 0, "target_id": 108, "target_slot": 0, "type": "MODEL"}, {"id": 281, "origin_id": 104, "origin_slot": 0, "target_id": 108, "target_slot": 1, "type": "CONDITIONING"}, {"id": 282, "origin_id": 104, "origin_slot": 1, "target_id": 108, "target_slot": 2, "type": "CONDITIONING"}, {"id": 285, "origin_id": 96, "origin_slot": 0, "target_id": 111, "target_slot": 0, "type": "VAE"}, {"id": 329, "origin_id": 110, "origin_slot": 2, "target_id": 111, "target_slot": 1, "type": "INT"}, {"id": 260, "origin_id": 126, "origin_slot": 0, "target_id": 123, "target_slot": 0, "type": "NOISE"}, {"id": 261, "origin_id": 93, "origin_slot": 0, "target_id": 123, "target_slot": 1, "type": "GUIDER"}, {"id": 262, "origin_id": 94, "origin_slot": 0, "target_id": 123, "target_slot": 2, "type": "SAMPLER"}, {"id": 263, "origin_id": 95, "origin_slot": 0, "target_id": 123, "target_slot": 3, "type": "SIGMAS"}, {"id": 323, "origin_id": 116, "origin_slot": 0, "target_id": 123, "target_slot": 4, "type": "LATENT"}, {"id": 296, "origin_id": 110, "origin_slot": 0, "target_id": 115, "target_slot": 0, "type": "INT"}, {"id": 297, "origin_id": 110, "origin_slot": 1, "target_id": 115, "target_slot": 1, "type": "INT"}, {"id": 330, "origin_id": 110, "origin_slot": 2, "target_id": 115, "target_slot": 2, "type": "INT"}, {"id": 325, "origin_id": 103, "origin_slot": 0, "target_id": 134, "target_slot": 0, "type": "MODEL"}, {"id": 292, "origin_id": 124, "origin_slot": 0, "target_id": 114, "target_slot": 0, "type": "CONDITIONING"}, {"id": 293, "origin_id": 119, "origin_slot": 0, "target_id": 114, "target_slot": 1, "type": "CONDITIONING"}, {"id": 294, "origin_id": 97, "origin_slot": 0, "target_id": 119, "target_slot": 0, "type": "CLIP"}, {"id": 324, "origin_id": 132, "origin_slot": 2, "target_id": 116, "target_slot": 0, "type": "LATENT"}, {"id": 300, "origin_id": 111, "origin_slot": 0, "target_id": 116, "target_slot": 1, "type": "LATENT"}, {"id": 313, "origin_id": 114, "origin_slot": 0, "target_id": 132, "target_slot": 0, "type": "CONDITIONING"}, {"id": 314, "origin_id": 114, "origin_slot": 1, "target_id": 132, "target_slot": 1, "type": "CONDITIONING"}, {"id": 328, "origin_id": 103, "origin_slot": 2, "target_id": 132, "target_slot": 2, "type": "VAE"}, {"id": 272, "origin_id": 123, "origin_slot": 0, "target_id": 122, "target_slot": 0, "type": "LATENT"}, {"id": 336, "origin_id": 107, "origin_slot": 1, "target_id": 138, "target_slot": 0, "type": "LATENT"}, {"id": 339, "origin_id": 139, "origin_slot": 0, "target_id": 106, "target_slot": 1, "type": "AUDIO"}, {"id": 295, "origin_id": 97, "origin_slot": 0, "target_id": 124, "target_slot": 0, "type": "CLIP"}, {"id": 303, "origin_id": 103, "origin_slot": 2, "target_id": 118, "target_slot": 0, "type": "VAE"}, {"id": 338, "origin_id": 138, "origin_slot": 1, "target_id": 139, "target_slot": 0, "type": "LATENT"}, {"id": 340, "origin_id": 96, "origin_slot": 0, "target_id": 139, "target_slot": 1, "type": "VAE"}, {"id": 337, "origin_id": 138, "origin_slot": 0, "target_id": 113, "target_slot": 0, 
"type": "LATENT"}, {"id": 291, "origin_id": 118, "origin_slot": 0, "target_id": 113, "target_slot": 1, "type": "VAE"}, {"id": 276, "origin_id": 108, "origin_slot": 0, "target_id": 107, "target_slot": 1, "type": "GUIDER"}, {"id": 277, "origin_id": 98, "origin_slot": 0, "target_id": 107, "target_slot": 2, "type": "SAMPLER"}, {"id": 278, "origin_id": 99, "origin_slot": 0, "target_id": 107, "target_slot": 3, "type": "SIGMAS"}, {"id": 279, "origin_id": 101, "origin_slot": 0, "target_id": 107, "target_slot": 4, "type": "LATENT"}, {"id": 327, "origin_id": 134, "origin_slot": 0, "target_id": 105, "target_slot": 0, "type": "MODEL"}, {"id": 310, "origin_id": 132, "origin_slot": 0, "target_id": 104, "target_slot": 0, "type": "CONDITIONING"}, {"id": 312, "origin_id": 132, "origin_slot": 1, "target_id": 104, "target_slot": 1, "type": "CONDITIONING"}, {"id": 270, "origin_id": 122, "origin_slot": 0, "target_id": 104, "target_slot": 2, "type": "LATENT"}, {"id": 287, "origin_id": 104, "origin_slot": 2, "target_id": 112, "target_slot": 0, "type": "LATENT"}, {"id": 288, "origin_id": 100, "origin_slot": 0, "target_id": 112, "target_slot": 1, "type": "LATENT_UPSCALE_MODEL"}, {"id": 289, "origin_id": 118, "origin_slot": 0, "target_id": 112, "target_slot": 2, "type": "VAE"}, {"id": 322, "origin_id": 116, "origin_slot": 0, "target_id": 95, "target_slot": 0, "type": "LATENT"}, {"id": 304, "origin_id": 106, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 345, "origin_id": -10, "origin_slot": 0, "target_id": 124, "target_slot": 1, "type": "STRING"}, {"id": 347, "origin_id": 143, "origin_slot": 0, "target_id": 107, "target_slot": 0, "type": "NOISE"}, {"id": 351, "origin_id": 138, "origin_slot": 0, "target_id": 144, "target_slot": 0, "type": "LATENT"}, {"id": 352, "origin_id": 144, "origin_slot": 0, "target_id": 106, "target_slot": 0, "type": "IMAGE"}, {"id": 353, "origin_id": 103, "origin_slot": 2, "target_id": 144, "target_slot": 1, "type": "VAE"}, {"id": 354, "origin_id": 145, "origin_slot": 0, "target_id": 111, "target_slot": 2, "type": "INT"}, {"id": 355, "origin_id": 148, "origin_slot": 0, "target_id": 114, "target_slot": 2, "type": "FLOAT"}, {"id": 356, "origin_id": 148, "origin_slot": 0, "target_id": 106, "target_slot": 2, "type": "FLOAT"}, {"id": 357, "origin_id": 149, "origin_slot": 0, "target_id": 132, "target_slot": 3, "type": "LATENT"}, {"id": 359, "origin_id": 103, "origin_slot": 2, "target_id": 149, "target_slot": 0, "type": "VAE"}, {"id": 360, "origin_id": 115, "origin_slot": 0, "target_id": 149, "target_slot": 2, "type": "LATENT"}, {"id": 363, "origin_id": -10, "origin_slot": 2, "target_id": 149, "target_slot": 4, "type": "BOOLEAN"}, {"id": 365, "origin_id": 151, "origin_slot": 0, "target_id": 101, "target_slot": 0, "type": "LATENT"}, {"id": 366, "origin_id": 112, "origin_slot": 0, "target_id": 151, "target_slot": 2, "type": "LATENT"}, {"id": 367, "origin_id": 118, "origin_slot": 0, "target_id": 151, "target_slot": 0, "type": "VAE"}, {"id": 368, "origin_id": -10, "origin_slot": 2, "target_id": 151, "target_slot": 4, "type": "BOOLEAN"}, {"id": 370, "origin_id": -10, "origin_slot": 1, "target_id": 149, "target_slot": 3, "type": "FLOAT"}, {"id": 371, "origin_id": -10, "origin_slot": 1, "target_id": 151, "target_slot": 3, "type": "FLOAT"}, {"id": 377, "origin_id": -10, "origin_slot": 6, "target_id": 96, "target_slot": 0, "type": "COMBO"}, {"id": 410, "origin_id": -10, "origin_slot": 4, "target_id": 151, "target_slot": 1, "type": "IMAGE"}, {"id": 412, "origin_id": 188, 
"origin_slot": 0, "target_id": 187, "target_slot": 0, "type": "IMAGE"}, {"id": 413, "origin_id": 191, "origin_slot": 0, "target_id": 189, "target_slot": 0, "type": "IMAGE"}, {"id": 414, "origin_id": 189, "origin_slot": 0, "target_id": 190, "target_slot": 0, "type": "IMAGE"}, {"id": 415, "origin_id": 187, "origin_slot": 0, "target_id": 191, "target_slot": 0, "type": "IMAGE"}, {"id": 416, "origin_id": 190, "origin_slot": 0, "target_id": 110, "target_slot": 0, "type": "IMAGE"}, {"id": 417, "origin_id": 190, "origin_slot": 0, "target_id": 149, "target_slot": 1, "type": "IMAGE"}, {"id": 418, "origin_id": 190, "origin_slot": 0, "target_id": 132, "target_slot": 4, "type": "IMAGE"}, {"id": 419, "origin_id": -10, "origin_slot": 3, "target_id": 188, "target_slot": 0, "type": "VIDEO"}, {"id": 420, "origin_id": -10, "origin_slot": 5, "target_id": 191, "target_slot": 2, "type": "INT"}, {"id": 421, "origin_id": -10, "origin_slot": 6, "target_id": 191, "target_slot": 3, "type": "INT"}, {"id": 422, "origin_id": -10, "origin_slot": 7, "target_id": 187, "target_slot": 2, "type": "INT"}, {"id": 425, "origin_id": -10, "origin_slot": 8, "target_id": 103, "target_slot": 0, "type": "COMBO"}, {"id": 426, "origin_id": -10, "origin_slot": 9, "target_id": 134, "target_slot": 1, "type": "COMBO"}, {"id": 427, "origin_id": -10, "origin_slot": 10, "target_id": 97, "target_slot": 0, "type": "COMBO"}, {"id": 429, "origin_id": -10, "origin_slot": 11, "target_id": 105, "target_slot": 1, "type": "COMBO"}, {"id": 430, "origin_id": -10, "origin_slot": 12, "target_id": 100, "target_slot": 0, "type": "COMBO"}, {"id": 431, "origin_id": -10, "origin_slot": 13, "target_id": 190, "target_slot": 2, "type": "COMBO"}, {"id": 432, "origin_id": -10, "origin_slot": 14, "target_id": 190, "target_slot": 3, "type": "COMBO"}, {"id": 433, "origin_id": -10, "origin_slot": 8, "target_id": 97, "target_slot": 1, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Depth to video"}, {"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2", "version": 1, "state": {"lastGroupId": 16, "lastNodeId": 191, "lastLinkId": 433, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image to Depth Map (Lotus)", "inputNode": {"id": -10, "bounding": [-60, -172.61268043518066, 126.625, 120]}, "outputNode": {"id": -20, "bounding": [1650, -172.61268043518066, 120, 60]}, "inputs": [{"id": "3bdd30c3-4ec9-485a-814b-e7d39fb6b5cc", "name": "pixels", "type": "IMAGE", "linkIds": [37], "localized_name": "pixels", "pos": [46.625, -152.61268043518066]}, {"id": "f9a1017c-f4b9-43b4-94c2-41c088b3a492", "name": "sigma", "type": "FLOAT", "linkIds": [243], "label": "depth_intensity", "pos": [46.625, -132.61268043518066]}, {"id": "374bfecc-34bb-47f9-82b6-cbe9383f8756", "name": "unet_name", "type": "COMBO", "linkIds": [423], "pos": [46.625, -112.61268043518066]}, {"id": "bb8707a1-46c3-44be-a15a-0adc908d871d", "name": "vae_name", "type": "COMBO", "linkIds": [424], "pos": [46.625, -92.61268043518066]}], "outputs": [{"id": "2ec278bd-0b66-4b30-9c5b-994d5f638214", "name": "IMAGE", "type": "IMAGE", "linkIds": [242], "localized_name": "IMAGE", "pos": [1670, -152.61268043518066]}], "widgets": [], "nodes": [{"id": 8, "type": "VAEDecode", "pos": [1380, -240], "size": [210, 46], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 232}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 240}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", 
"slot_index": 0, "links": [35]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 10, "type": "UNETLoader", "pos": [135.34181213378906, -290.1947937011719], "size": [305.93701171875, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 423}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [31, 241]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "UNETLoader", "models": [{"name": "lotus-depth-d-v1-1.safetensors", "url": "https://huggingface.co/Comfy-Org/lotus/resolve/main/lotus-depth-d-v1-1.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["lotus-depth-d-v1-1.safetensors", "default"]}, {"id": 14, "type": "VAELoader", "pos": [134.531494140625, -165.18197631835938], "size": [305.93701171875, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 424}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [38, 240]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAELoader", "models": [{"name": "vae-ft-mse-840000-ema-pruned.safetensors", "url": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["vae-ft-mse-840000-ema-pruned.safetensors"]}, {"id": 16, "type": "SamplerCustomAdvanced", "pos": [990.6585693359375, -319.9144287109375], "size": [355.20001220703125, 326], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 237}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 27}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 33}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 194}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 201}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "slot_index": 0, "links": [232]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "slot_index": 1, "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 18, "type": "DisableNoise", "pos": [730.47705078125, -320], "size": [210, 26], "flags": {}, "order": 0, "mode": 0, "inputs": [], 
"outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "slot_index": 0, "links": [237]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "DisableNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 19, "type": "BasicGuider", "pos": [730.2631225585938, -251.22537231445312], "size": [210, 46], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 241}, {"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 238}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "slot_index": 0, "links": [27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 20, "type": "BasicScheduler", "pos": [488.64459228515625, -147.67201232910156], "size": [210, 106], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 31}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [66]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicScheduler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["normal", 1, 1]}, {"id": 21, "type": "KSamplerSelect", "pos": [730.2631225585938, -161.22537231445312], "size": [210, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "slot_index": 0, "links": [33]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["euler"]}, {"id": 22, "type": "ImageInvert", "pos": [1380, -310], "size": [210, 26], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 35}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [242]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "ImageInvert", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 23, "type": "VAEEncode", "pos": [730.2631225585938, 38.77463912963867], "size": [210, 46], "flags": {}, 
"order": 10, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 37}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 38}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [201]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 28, "type": "SetFirstSigma", "pos": [730.2631225585938, -61.22536087036133], "size": [210, 58], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 66}, {"localized_name": "sigma", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": 243}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [194]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SetFirstSigma", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": [999.0000000000002]}, {"id": 68, "type": "LotusConditioning", "pos": [490, -230], "size": [210, 26], "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "slot_index": 0, "links": [238]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "LotusConditioning", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}], "groups": [{"id": 1, "title": "Load Models", "bounding": [120, -370, 335, 281.6000061035156], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 232, "origin_id": 16, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 240, "origin_id": 14, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 237, "origin_id": 18, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "NOISE"}, {"id": 27, "origin_id": 19, "origin_slot": 0, "target_id": 16, "target_slot": 1, "type": "GUIDER"}, {"id": 33, "origin_id": 21, "origin_slot": 0, "target_id": 16, "target_slot": 2, "type": "SAMPLER"}, {"id": 194, "origin_id": 28, "origin_slot": 0, "target_id": 16, "target_slot": 3, "type": "SIGMAS"}, {"id": 201, "origin_id": 23, "origin_slot": 0, "target_id": 16, "target_slot": 4, "type": "LATENT"}, {"id": 241, "origin_id": 10, "origin_slot": 0, "target_id": 19, "target_slot": 0, "type": "MODEL"}, {"id": 238, "origin_id": 68, "origin_slot": 0, "target_id": 19, "target_slot": 1, "type": "CONDITIONING"}, {"id": 31, "origin_id": 10, "origin_slot": 0, "target_id": 20, "target_slot": 0, "type": "MODEL"}, {"id": 35, "origin_id": 8, "origin_slot": 0, "target_id": 22, "target_slot": 0, "type": "IMAGE"}, {"id": 38, "origin_id": 14, "origin_slot": 0, "target_id": 23, "target_slot": 1, "type": "VAE"}, {"id": 66, "origin_id": 20, "origin_slot": 0, "target_id": 28, "target_slot": 0, "type": "SIGMAS"}, {"id": 37, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 242, "origin_id": 22, "origin_slot": 0, "target_id": -20, "target_slot": 
0, "type": "IMAGE"}, {"id": 243, "origin_id": -10, "origin_slot": 1, "target_id": 28, "target_slot": 1, "type": "FLOAT"}, {"id": 423, "origin_id": -10, "origin_slot": 2, "target_id": 10, "target_slot": 0, "type": "COMBO"}, {"id": 424, "origin_id": -10, "origin_slot": 3, "target_id": 14, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}}]}, "config": {}, "extra": {"ds": {"scale": 1.313181818181818, "offset": [271.9196871428176, -3845.0123774536323]}, "workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Edge-Preserving Blur.json b/blueprints/Edge-Preserving Blur.json new file mode 100644 index 000000000..4f2416e9b --- /dev/null +++ b/blueprints/Edge-Preserving Blur.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 136, "last_link_id": 0, "nodes": [{"id": 136, "type": "c6dc0f88-416b-4db1-bed1-442d793de5ad", "pos": [669.0822222222221, 835.5507407407408], "size": [210, 106], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["130", "value"], ["131", "value"], ["133", "value"]]}, "widgets_values": [], "title": "Edge-Preserving Blur"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "c6dc0f88-416b-4db1-bed1-442d793de5ad", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 138, "lastLinkId": 109, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Edge-Preserving Blur", "inputNode": {"id": -10, "bounding": [1750, -620, 120, 60]}, "outputNode": {"id": -20, "bounding": [2700, -620, 120, 60]}, "inputs": [{"id": "06a6d0ad-25d7-4784-8c72-7fc8e7110a22", "name": "images.image0", "type": "IMAGE", "linkIds": [106], "localized_name": "images.image0", "label": "image", "pos": [1850, -600]}], "outputs": [{"id": "3ae9f5d7-be63-4c9f-9893-6f848defa377", "name": "IMAGE0", "type": "IMAGE", "linkIds": [99], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [2720, -600]}], "widgets": [], "nodes": [{"id": 128, "type": "GLSLShader", "pos": [2220, -860], "size": [420, 252], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 106}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 100}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 101}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 107}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": 103}, {"label": "u_int2", "localized_name": "ints.u_int2", "name": "ints.u_int2", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", 
"name": "IMAGE0", "type": "IMAGE", "links": [99]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // Blur radius (0–20, default ~5)\nuniform float u_float1; // Edge threshold (0–100, default ~30)\nuniform int u_int0; // Step size (0/1 = every pixel, 2+ = skip pixels)\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int MAX_RADIUS = 20;\nconst float EPSILON = 0.0001;\n\n// Perceptual luminance\nfloat getLuminance(vec3 rgb) {\n return dot(rgb, vec3(0.299, 0.587, 0.114));\n}\n\nvec4 bilateralFilter(vec2 uv, vec2 texelSize, int radius,\n float sigmaSpatial, float sigmaColor)\n{\n vec4 center = texture(u_image0, uv);\n vec3 centerRGB = center.rgb;\n\n float invSpatial2 = -0.5 / (sigmaSpatial * sigmaSpatial);\n float invColor2 = -0.5 / (sigmaColor * sigmaColor + EPSILON);\n\n vec3 sumRGB = vec3(0.0);\n float sumWeight = 0.0;\n\n int step = max(u_int0, 1);\n float radius2 = float(radius * radius);\n\n for (int dy = -MAX_RADIUS; dy <= MAX_RADIUS; dy++) {\n if (dy < -radius || dy > radius) continue;\n if (abs(dy) % step != 0) continue;\n\n for (int dx = -MAX_RADIUS; dx <= MAX_RADIUS; dx++) {\n if (dx < -radius || dx > radius) continue;\n if (abs(dx) % step != 0) continue;\n\n vec2 offset = vec2(float(dx), float(dy));\n float dist2 = dot(offset, offset);\n if (dist2 > radius2) continue;\n\n vec3 sampleRGB = texture(u_image0, uv + offset * texelSize).rgb;\n\n // Spatial Gaussian\n float spatialWeight = exp(dist2 * invSpatial2);\n\n // Perceptual color distance (weighted RGB)\n vec3 diff = sampleRGB - centerRGB;\n float colorDist = dot(diff * diff, vec3(0.299, 0.587, 0.114));\n float colorWeight = exp(colorDist * invColor2);\n\n float w = spatialWeight * colorWeight;\n sumRGB += sampleRGB * w;\n sumWeight += w;\n }\n }\n\n vec3 resultRGB = sumRGB / max(sumWeight, EPSILON);\n return vec4(resultRGB, center.a); // preserve center alpha\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n\n float radiusF = clamp(u_float0, 0.0, float(MAX_RADIUS));\n int radius = int(radiusF + 0.5);\n\n if (radius == 0) {\n fragColor = texture(u_image0, v_texCoord);\n return;\n }\n\n // Edge threshold → color sigma\n // Squared curve for better low-end control\n float t = clamp(u_float1, 0.0, 100.0) / 100.0;\n t *= t;\n float sigmaColor = mix(0.01, 0.5, t);\n\n // Spatial sigma tied to radius\n float sigmaSpatial = max(radiusF * 0.75, 0.5);\n\n fragColor = bilateralFilter(\n v_texCoord,\n texelSize,\n radius,\n sigmaSpatial,\n sigmaColor\n );\n}", "from_input"]}, {"id": 130, "type": "PrimitiveFloat", "pos": [1930, -860], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "blur_radius", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [100]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 20, "step": 0.5, "precision": 1}, "widgets_values": [20]}, {"id": 131, "type": "PrimitiveFloat", "pos": [1930, -760], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "edge_threshold", "localized_name": "value", "name": "value", "type": 
"FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [101]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "step": 1}, "widgets_values": [50]}, {"id": 133, "type": "PrimitiveInt", "pos": [1930, -660], "size": [270, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "step_size", "localized_name": "value", "name": "value", "type": "INT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "INT", "name": "INT", "type": "INT", "links": [103, 107]}], "properties": {"Node name for S&R": "PrimitiveInt", "min": 0}, "widgets_values": [1, "fixed"]}], "groups": [], "links": [{"id": 100, "origin_id": 130, "origin_slot": 0, "target_id": 128, "target_slot": 2, "type": "FLOAT"}, {"id": 101, "origin_id": 131, "origin_slot": 0, "target_id": 128, "target_slot": 3, "type": "FLOAT"}, {"id": 107, "origin_id": 133, "origin_slot": 0, "target_id": 128, "target_slot": 5, "type": "INT"}, {"id": 103, "origin_id": 133, "origin_slot": 0, "target_id": 128, "target_slot": 6, "type": "INT"}, {"id": 106, "origin_id": -10, "origin_slot": 0, "target_id": 128, "target_slot": 0, "type": "IMAGE"}, {"id": 99, "origin_id": 128, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Blur"}]}, "extra": {}} diff --git a/blueprints/Film Grain.json b/blueprints/Film Grain.json new file mode 100644 index 000000000..b7ebe2a36 --- /dev/null +++ b/blueprints/Film Grain.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 22, "last_link_id": 0, "nodes": [{"id": 22, "type": "3324cf54-bcff-405f-a4bf-c5122c72fe56", "pos": [4800, -1180], "size": [250, 154], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Film Grain", "properties": {"proxyWidgets": [["17", "value"], ["18", "value"], ["19", "value"], ["20", "value"], ["21", "choice"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "3324cf54-bcff-405f-a4bf-c5122c72fe56", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 21, "lastLinkId": 30, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Film Grain", "inputNode": {"id": -10, "bounding": [4096.671470760602, -948.2184031393472, 120, 60]}, "outputNode": {"id": -20, "bounding": [4900, -948.2184031393472, 120, 60]}, "inputs": [{"id": "062968ea-da25-47e7-a180-d913c267f148", "name": "images.image0", "type": "IMAGE", "linkIds": [22], "localized_name": "images.image0", "label": "image", "pos": [4196.671470760602, -928.2184031393472]}], "outputs": [{"id": "43247d06-a39f-4733-9828-c39400fe02a4", "name": "IMAGE0", "type": "IMAGE", "linkIds": [23], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4920, -928.2184031393472]}], "widgets": [], "nodes": [{"id": 15, "type": "GLSLShader", "pos": [4510, -1180], "size": [330, 272], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 22}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 26}, {"label": 
"u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 27}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 28}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 29}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 30}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [23]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // grain amount [0.0 – 1.0] typical: 0.2–0.8\nuniform float u_float1; // grain size [0.3 – 3.0] lower = finer grain\nuniform float u_float2; // color amount [0.0 – 1.0] 0 = monochrome, 1 = RGB grain\nuniform float u_float3; // luminance bias [0.0 – 1.0] 0 = uniform, 1 = shadows only\nuniform int u_int0; // noise mode [0 or 1] 0 = smooth, 1 = grainy\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\n// High-quality integer hash (pcg-like)\nuint pcg(uint v) {\n uint state = v * 747796405u + 2891336453u;\n uint word = ((state >> ((state >> 28u) + 4u)) ^ state) * 277803737u;\n return (word >> 22u) ^ word;\n}\n\n// 2D -> 1D hash input\nuint hash2d(uvec2 p) {\n return pcg(p.x + pcg(p.y));\n}\n\n// Hash to float [0, 1]\nfloat hashf(uvec2 p) {\n return float(hash2d(p)) / float(0xffffffffu);\n}\n\n// Hash to float with offset (for RGB channels)\nfloat hashf(uvec2 p, uint offset) {\n return float(pcg(hash2d(p) + offset)) / float(0xffffffffu);\n}\n\n// Convert uniform [0,1] to roughly Gaussian distribution\n// Using simple approximation: average of multiple samples\nfloat toGaussian(uvec2 p) {\n float sum = hashf(p, 0u) + hashf(p, 1u) + hashf(p, 2u) + hashf(p, 3u);\n return (sum - 2.0) * 0.7; // Centered, scaled\n}\n\nfloat toGaussian(uvec2 p, uint offset) {\n float sum = hashf(p, offset) + hashf(p, offset + 1u) \n + hashf(p, offset + 2u) + hashf(p, offset + 3u);\n return (sum - 2.0) * 0.7;\n}\n\n// Smooth noise with better interpolation\nfloat smoothNoise(vec2 p) {\n vec2 i = floor(p);\n vec2 f = fract(p);\n \n // Quintic interpolation (less banding than cubic)\n f = f * f * f * (f * (f * 6.0 - 15.0) + 10.0);\n \n uvec2 ui = uvec2(i);\n float a = toGaussian(ui);\n float b = toGaussian(ui + uvec2(1u, 0u));\n float c = toGaussian(ui + uvec2(0u, 1u));\n float d = toGaussian(ui + uvec2(1u, 1u));\n \n return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);\n}\n\nfloat smoothNoise(vec2 p, uint offset) {\n vec2 i = floor(p);\n vec2 f = fract(p);\n \n f = f * f * f * (f * (f * 6.0 - 15.0) 
+ 10.0);\n \n uvec2 ui = uvec2(i);\n float a = toGaussian(ui, offset);\n float b = toGaussian(ui + uvec2(1u, 0u), offset);\n float c = toGaussian(ui + uvec2(0u, 1u), offset);\n float d = toGaussian(ui + uvec2(1u, 1u), offset);\n \n return mix(mix(a, b, f.x), mix(c, d, f.x), f.y);\n}\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n \n // Luminance (Rec.709)\n float luma = dot(color.rgb, vec3(0.2126, 0.7152, 0.0722));\n \n // Grain UV (resolution-independent)\n vec2 grainUV = v_texCoord * u_resolution / max(u_float1, 0.01);\n uvec2 grainPixel = uvec2(grainUV);\n \n float g;\n vec3 grainRGB;\n \n if (u_int0 == 1) {\n // Grainy mode: pure hash noise (no interpolation = no banding)\n g = toGaussian(grainPixel);\n grainRGB = vec3(\n toGaussian(grainPixel, 100u),\n toGaussian(grainPixel, 200u),\n toGaussian(grainPixel, 300u)\n );\n } else {\n // Smooth mode: interpolated with quintic curve\n g = smoothNoise(grainUV);\n grainRGB = vec3(\n smoothNoise(grainUV, 100u),\n smoothNoise(grainUV, 200u),\n smoothNoise(grainUV, 300u)\n );\n }\n \n // Luminance weighting (less grain in highlights)\n float lumWeight = mix(1.0, 1.0 - luma, clamp(u_float3, 0.0, 1.0));\n \n // Strength\n float strength = u_float0 * 0.15;\n \n // Color vs monochrome grain\n vec3 grainColor = mix(vec3(g), grainRGB, clamp(u_float2, 0.0, 1.0));\n \n color.rgb += grainColor * strength * lumWeight;\n fragColor0 = vec4(clamp(color.rgb, 0.0, 1.0), color.a);\n}\n", "from_input"]}, {"id": 21, "type": "CustomCombo", "pos": [4280, -780], "size": [210, 153.8888931274414], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "grain_mode", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [30]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["Smooth", 0, "Smooth", "Grainy", ""]}, {"id": 17, "type": "PrimitiveFloat", "pos": [4276.671470760602, -1180.3256994061358], "size": [210, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "grain_amount", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [26]}], "title": "Grain amount", "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 1, "step": 0.05, "precision": 2}, "widgets_values": [0.25]}, {"id": 18, "type": "PrimitiveFloat", "pos": [4280, -1080], "size": [210, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "grain_size", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [27]}], "title": "Grain size", "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0.05, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.1]}, {"id": 19, "type": "PrimitiveFloat", "pos": [4280, -980], "size": [210, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "color_amount", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [28]}], "title": "Color amount", "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 1, "precision": 2, "step": 0.05}, "widgets_values": 
[0]}, {"id": 20, "type": "PrimitiveFloat", "pos": [4280, -880], "size": [210, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "shadow_focus", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [29]}], "title": "Luminance bias", "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 1, "precision": 2, "step": 0.05}, "widgets_values": [0]}], "groups": [], "links": [{"id": 26, "origin_id": 17, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "FLOAT"}, {"id": 27, "origin_id": 18, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "FLOAT"}, {"id": 28, "origin_id": 19, "origin_slot": 0, "target_id": 15, "target_slot": 4, "type": "FLOAT"}, {"id": 29, "origin_id": 20, "origin_slot": 0, "target_id": 15, "target_slot": 5, "type": "FLOAT"}, {"id": 30, "origin_id": 21, "origin_slot": 1, "target_id": 15, "target_slot": 7, "type": "INT"}, {"id": 22, "origin_id": -10, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "IMAGE"}, {"id": 23, "origin_id": 15, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Glow.json b/blueprints/Glow.json new file mode 100644 index 000000000..590445c06 --- /dev/null +++ b/blueprints/Glow.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 37, "last_link_id": 0, "nodes": [{"id": 37, "type": "0a99445a-aaf8-4a7f-aec3-d7d710ae1495", "pos": [2160, -360], "size": [260, 154], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["34", "value"], ["35", "value"], ["33", "value"], ["31", "choice"], ["32", "color"]]}, "widgets_values": [], "title": "Glow"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "0a99445a-aaf8-4a7f-aec3-d7d710ae1495", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 36, "lastLinkId": 53, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Glow", "inputNode": {"id": -10, "bounding": [2110, -165, 120, 60]}, "outputNode": {"id": -20, "bounding": [3170, -165, 120, 60]}, "inputs": [{"id": "ffc7cf94-be90-4d56-a3b8-d0514d61c015", "name": "images.image0", "type": "IMAGE", "linkIds": [45], "localized_name": "images.image0", "label": "image", "pos": [2210, -145]}], "outputs": [{"id": "04986101-50be-4762-8957-8e2a5e460bbb", "name": "IMAGE0", "type": "IMAGE", "linkIds": [53], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [3190, -145]}], "widgets": [], "nodes": [{"id": 30, "type": "GLSLShader", "pos": [2590, -520], "size": [520, 272], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 45}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 51}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 50}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, 
"type": "FLOAT", "link": 52}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 46}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": 47}, {"label": "u_int2", "localized_name": "ints.u_int2", "name": "ints.u_int2", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [53]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; 
i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", "from_input"]}, {"id": 34, "type": "PrimitiveFloat", "pos": [2290, -510], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "intensity", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [51]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "precision": 1, "step": 1}, "widgets_values": [30]}, {"id": 35, "type": "PrimitiveFloat", "pos": [2290, -410], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "radius", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [50]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "precision": 1, "step": 1}, "widgets_values": [25]}, {"id": 33, "type": "PrimitiveFloat", "pos": [2290, -310], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "threshold", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [52]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 100, "precision": 1, "step": 1}, "widgets_values": [100]}, {"id": 32, "type": "ColorToRGBInt", "pos": [2290, -210], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "color_tint", "localized_name": "color", "name": "color", "type": "COLOR", "widget": {"name": "color"}, "link": null}], "outputs": [{"localized_name": "rgb_int", "name": "rgb_int", "type": "INT", "links": [47]}], "properties": {"Node name for S&R": "ColorToRGBInt"}, "widgets_values": ["#45edf5"]}, {"id": 31, "type": "CustomCombo", "pos": [2290, -110], "size": [270, 222], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "blend_mode", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [46]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["add", 0, "add", "screen", "soft", "overlay", "lighten", ""]}], "groups": [], "links": [{"id": 51, "origin_id": 34, "origin_slot": 0, "target_id": 30, "target_slot": 2, "type": "FLOAT"}, {"id": 50, "origin_id": 35, "origin_slot": 0, "target_id": 30, "target_slot": 3, "type": "FLOAT"}, {"id": 52, "origin_id": 33, "origin_slot": 0, "target_id": 30, "target_slot": 4, "type": "FLOAT"}, {"id": 46, "origin_id": 31, "origin_slot": 
1, "target_id": 30, "target_slot": 6, "type": "INT"}, {"id": 47, "origin_id": 32, "origin_slot": 0, "target_id": 30, "target_slot": 7, "type": "INT"}, {"id": 45, "origin_id": -10, "origin_slot": 0, "target_id": 30, "target_slot": 0, "type": "IMAGE"}, {"id": 53, "origin_id": 30, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Hue and Saturation.json b/blueprints/Hue and Saturation.json new file mode 100644 index 000000000..04846c51d --- /dev/null +++ b/blueprints/Hue and Saturation.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 11, "last_link_id": 0, "nodes": [{"id": 11, "type": "c64f83e9-aa5d-4031-89f1-0704e39299fe", "pos": [870, -220], "size": [250, 178], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Hue and Saturation", "properties": {"proxyWidgets": [["2", "choice"], ["4", "value"], ["5", "value"], ["6", "value"], ["7", "value"], ["3", "choice"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "c64f83e9-aa5d-4031-89f1-0704e39299fe", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 10, "lastLinkId": 11, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Hue and Saturation", "inputNode": {"id": -10, "bounding": [360, -176, 120, 60]}, "outputNode": {"id": -20, "bounding": [1410, -176, 120, 60]}, "inputs": [{"id": "a5aae7ea-b511-4045-b5da-94101e269cd7", "name": "images.image0", "type": "IMAGE", "linkIds": [10], "localized_name": "images.image0", "label": "image", "pos": [460, -156]}], "outputs": [{"id": "30b72604-69b3-4944-b253-a9099bbd73a9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [8], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [1430, -156]}], "widgets": [], "nodes": [{"id": 3, "type": "CustomCombo", "pos": [540, -240], "size": [270, 150], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "color_space", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [2]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["HSL", 0, "HSL", "HSB/HSV", ""]}, {"id": 2, "type": "CustomCombo", "pos": [540, -580], "size": [270, 294], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "mode", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [1]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["Master", 0, "Master", "Reds", "Yellows", "Greens", "Cyans", "Blues", "Magentas", "Colorize", ""]}, {"id": 7, "type": "PrimitiveFloat", "pos": [540, 260], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "overlap", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [6]}], "properties": {"Node name for S&R": "PrimitiveFloat", 
"min": 0, "max": 100, "precision": 1, "step": 1}, "widgets_values": [50]}, {"id": 6, "type": "PrimitiveFloat", "pos": [540, 160], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "brightness", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [5]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": -100, "max": 100, "precision": 1, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [0]}, {"id": 5, "type": "PrimitiveFloat", "pos": [540, 60], "size": [270, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "saturation", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [4]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": -100, "max": 100, "precision": 1, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 4, "type": "PrimitiveFloat", "pos": [540, -40], "size": [270, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "hue", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [3]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": -180, "max": 180, "precision": 1, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [255, 0, 0]}, {"offset": 0.16666666666666666, "color": [255, 255, 0]}, {"offset": 0.3333333333333333, "color": [0, 255, 0]}, {"offset": 0.5, "color": [0, 255, 255]}, {"offset": 0.6666666666666666, "color": [0, 0, 255]}, {"offset": 0.8333333333333334, "color": [255, 0, 255]}, {"offset": 1, "color": [255, 0, 0]}]}, "widgets_values": [0]}, {"id": 1, "type": "GLSLShader", "pos": [880, -300], "size": [470, 292], "flags": {}, "order": 6, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 10}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 3}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 4}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 5}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 6}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 1}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": 2}, {"label": "u_int2", "localized_name": "ints.u_int2", "name": "ints.u_int2", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": 
{"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [8]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Mode: 0=Master, 1=Reds, 2=Yellows, 3=Greens, 4=Cyans, 5=Blues, 6=Magentas, 7=Colorize\nuniform int u_int1; // Color Space: 0=HSL, 1=HSB/HSV\nuniform float u_float0; // Hue (-180 to 180)\nuniform float u_float1; // Saturation (-100 to 100)\nuniform float u_float2; // Lightness/Brightness (-100 to 100)\nuniform float u_float3; // Overlap (0 to 100) - feathering between adjacent color ranges\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\n// Color range modes\nconst int MODE_MASTER = 0;\nconst int MODE_RED = 1;\nconst int MODE_YELLOW = 2;\nconst int MODE_GREEN = 3;\nconst int MODE_CYAN = 4;\nconst int MODE_BLUE = 5;\nconst int MODE_MAGENTA = 6;\nconst int MODE_COLORIZE = 7;\n\n// Color space modes\nconst int COLORSPACE_HSL = 0;\nconst int COLORSPACE_HSB = 1;\n\nconst float EPSILON = 0.0001;\n\n//=============================================================================\n// RGB <-> HSL Conversions\n//=============================================================================\n\nvec3 rgb2hsl(vec3 c) {\n float maxC = max(max(c.r, c.g), c.b);\n float minC = min(min(c.r, c.g), c.b);\n float delta = maxC - minC;\n\n float h = 0.0;\n float s = 0.0;\n float l = (maxC + minC) * 0.5;\n\n if (delta > EPSILON) {\n s = l < 0.5\n ? delta / (maxC + minC)\n : delta / (2.0 - maxC - minC);\n\n if (maxC == c.r) {\n h = (c.g - c.b) / delta + (c.g < c.b ? 6.0 : 0.0);\n } else if (maxC == c.g) {\n h = (c.b - c.r) / delta + 2.0;\n } else {\n h = (c.r - c.g) / delta + 4.0;\n }\n h /= 6.0;\n }\n\n return vec3(h, s, l);\n}\n\nfloat hue2rgb(float p, float q, float t) {\n t = fract(t);\n if (t < 1.0/6.0) return p + (q - p) * 6.0 * t;\n if (t < 0.5) return q;\n if (t < 2.0/3.0) return p + (q - p) * (2.0/3.0 - t) * 6.0;\n return p;\n}\n\nvec3 hsl2rgb(vec3 hsl) {\n if (hsl.y < EPSILON) return vec3(hsl.z);\n\n float q = hsl.z < 0.5\n ? hsl.z * (1.0 + hsl.y)\n : hsl.z + hsl.y - hsl.z * hsl.y;\n float p = 2.0 * hsl.z - q;\n\n return vec3(\n hue2rgb(p, q, hsl.x + 1.0/3.0),\n hue2rgb(p, q, hsl.x),\n hue2rgb(p, q, hsl.x - 1.0/3.0)\n );\n}\n\nvec3 rgb2hsb(vec3 c) {\n float maxC = max(max(c.r, c.g), c.b);\n float minC = min(min(c.r, c.g), c.b);\n float delta = maxC - minC;\n\n float h = 0.0;\n float s = (maxC > EPSILON) ? delta / maxC : 0.0;\n float b = maxC;\n\n if (delta > EPSILON) {\n if (maxC == c.r) {\n h = (c.g - c.b) / delta + (c.g < c.b ? 
6.0 : 0.0);\n } else if (maxC == c.g) {\n h = (c.b - c.r) / delta + 2.0;\n } else {\n h = (c.r - c.g) / delta + 4.0;\n }\n h /= 6.0;\n }\n\n return vec3(h, s, b);\n}\n\nvec3 hsb2rgb(vec3 hsb) {\n vec3 rgb = clamp(abs(mod(hsb.x * 6.0 + vec3(0.0, 4.0, 2.0), 6.0) - 3.0) - 1.0, 0.0, 1.0);\n return hsb.z * mix(vec3(1.0), rgb, hsb.y);\n}\n\n//=============================================================================\n// Color Range Weight Calculation\n//=============================================================================\n\nfloat hueDistance(float a, float b) {\n float d = abs(a - b);\n return min(d, 1.0 - d);\n}\n\nfloat getHueWeight(float hue, float center, float overlap) {\n float baseWidth = 1.0 / 6.0;\n float feather = baseWidth * overlap;\n\n float d = hueDistance(hue, center);\n\n float inner = baseWidth * 0.5;\n float outer = inner + feather;\n\n return 1.0 - smoothstep(inner, outer, d);\n}\n\nfloat getModeWeight(float hue, int mode, float overlap) {\n if (mode == MODE_MASTER || mode == MODE_COLORIZE) return 1.0;\n\n if (mode == MODE_RED) {\n return max(\n getHueWeight(hue, 0.0, overlap),\n getHueWeight(hue, 1.0, overlap)\n );\n }\n\n float center = float(mode - 1) / 6.0;\n return getHueWeight(hue, center, overlap);\n}\n\n//=============================================================================\n// Adjustment Functions\n//=============================================================================\n\nfloat adjustLightness(float l, float amount) {\n return amount > 0.0\n ? l + (1.0 - l) * amount\n : l + l * amount;\n}\n\nfloat adjustBrightness(float b, float amount) {\n return clamp(b + amount, 0.0, 1.0);\n}\n\nfloat adjustSaturation(float s, float amount) {\n return amount > 0.0\n ? s + (1.0 - s) * amount\n : s + s * amount;\n}\n\nvec3 colorize(vec3 rgb, float hue, float sat, float light) {\n float lum = dot(rgb, vec3(0.299, 0.587, 0.114));\n float l = adjustLightness(lum, light);\n\n vec3 hsl = vec3(fract(hue), clamp(sat, 0.0, 1.0), clamp(l, 0.0, 1.0));\n return hsl2rgb(hsl);\n}\n\n//=============================================================================\n// Main\n//=============================================================================\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n\n float hueShift = u_float0 / 360.0; // -180..180 -> -0.5..0.5\n float satAmount = u_float1 / 100.0; // -100..100 -> -1..1\n float lightAmount= u_float2 / 100.0; // -100..100 -> -1..1\n float overlap = u_float3 / 100.0; // 0..100 -> 0..1\n\n vec3 result;\n\n if (u_int0 == MODE_COLORIZE) {\n result = colorize(original.rgb, hueShift, satAmount, lightAmount);\n fragColor = vec4(result, original.a);\n return;\n }\n\n vec3 hsx = (u_int1 == COLORSPACE_HSL)\n ? rgb2hsl(original.rgb)\n : rgb2hsb(original.rgb);\n\n float weight = getModeWeight(hsx.x, u_int0, overlap);\n\n if (u_int0 != MODE_MASTER && hsx.y < EPSILON) {\n weight = 0.0;\n }\n\n if (weight > EPSILON) {\n float h = fract(hsx.x + hueShift * weight);\n float s = clamp(adjustSaturation(hsx.y, satAmount * weight), 0.0, 1.0);\n float v = (u_int1 == COLORSPACE_HSL)\n ? clamp(adjustLightness(hsx.z, lightAmount * weight), 0.0, 1.0)\n : clamp(adjustBrightness(hsx.z, lightAmount * weight), 0.0, 1.0);\n\n vec3 adjusted = vec3(h, s, v);\n result = (u_int1 == COLORSPACE_HSL)\n ? 
hsl2rgb(adjusted)\n : hsb2rgb(adjusted);\n } else {\n result = original.rgb;\n }\n\n fragColor = vec4(result, original.a);\n}\n", "from_input"]}], "groups": [], "links": [{"id": 3, "origin_id": 4, "origin_slot": 0, "target_id": 1, "target_slot": 2, "type": "FLOAT"}, {"id": 4, "origin_id": 5, "origin_slot": 0, "target_id": 1, "target_slot": 3, "type": "FLOAT"}, {"id": 5, "origin_id": 6, "origin_slot": 0, "target_id": 1, "target_slot": 4, "type": "FLOAT"}, {"id": 6, "origin_id": 7, "origin_slot": 0, "target_id": 1, "target_slot": 5, "type": "FLOAT"}, {"id": 1, "origin_id": 2, "origin_slot": 1, "target_id": 1, "target_slot": 7, "type": "INT"}, {"id": 2, "origin_id": 3, "origin_slot": 1, "target_id": 1, "target_slot": 8, "type": "INT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 1, "target_slot": 0, "type": "IMAGE"}, {"id": 8, "origin_id": 1, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Image Blur.json b/blueprints/Image Blur.json new file mode 100644 index 000000000..4b9e74255 --- /dev/null +++ b/blueprints/Image Blur.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 8, "last_link_id": 0, "nodes": [{"id": 8, "type": "198632a3-ee76-4aab-9ce7-a69c624eaff9", "pos": [4470, -1840], "size": [210, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "blurred_image", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["12", "choice"], ["10", "value"]]}, "widgets_values": [], "title": "Image Blur"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "198632a3-ee76-4aab-9ce7-a69c624eaff9", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 12, "lastLinkId": 11, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Blur", "inputNode": {"id": -10, "bounding": [3540, -2445, 120, 60]}, "outputNode": {"id": -20, "bounding": [4620, -2445, 121.11666870117188, 60]}, "inputs": [{"id": "7ff2a402-6b11-45e8-a92a-7158d216520a", "name": "images.image0", "type": "IMAGE", "linkIds": [9], "localized_name": "images.image0", "label": "image", "pos": [3640, -2425]}], "outputs": [{"id": "80a8e19e-ffd9-44a5-90f2-710815a5b063", "name": "IMAGE0", "type": "IMAGE", "linkIds": [3], "localized_name": "IMAGE0", "label": "blurred_image", "pos": [4640, -2425]}], "widgets": [], "nodes": [{"id": 12, "type": "CustomCombo", "pos": [3720, -2620], "size": [270, 174], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "blur_type", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [11]}], "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["Gaussian", 0, "Gaussian", "Box", "Radial", ""]}, {"id": 10, "type": "PrimitiveFloat", "pos": [4020, -2780], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [10]}], "properties": {"Node name for S&R": "PrimitiveFloat", "max": 100, "min": 0}, "widgets_values": 
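The Hue/Saturation shader above, in its Master mode, converts each pixel to HSL, offsets the hue by u_float0 / 360, and blends saturation and lightness toward 0 or 1 depending on the sign of the adjustment. A minimal Python sketch of that per-pixel math, using the stdlib colorsys module in place of the GLSL conversions (the function names here are illustrative, not part of any node):

```python
import colorsys

def blend_toward_extreme(x: float, amount: float) -> float:
    # Mirrors the shader's adjustSaturation/adjustLightness curve:
    # positive amounts blend toward 1.0, negative amounts scale toward 0.0.
    return x + (1.0 - x) * amount if amount > 0.0 else x + x * amount

def hsl_adjust(rgb, hue_deg=0.0, sat_pct=0.0, light_pct=0.0):
    # Master-mode adjustment for one RGB pixel with components in 0..1.
    h, l, s = colorsys.rgb_to_hls(*rgb)            # stdlib order is H, L, S
    h = (h + hue_deg / 360.0) % 1.0                # -180..180 deg -> -0.5..0.5 turn
    s = min(max(blend_toward_extreme(s, sat_pct / 100.0), 0.0), 1.0)
    l = min(max(blend_toward_extreme(l, light_pct / 100.0), 0.0), 1.0)
    return colorsys.hls_to_rgb(h, l, s)

print(hsl_adjust((0.8, 0.2, 0.2), hue_deg=120.0))  # shifts red toward green
```

The non-Master modes apply the same adjustment scaled by a per-pixel weight from getHueWeight, so only pixels whose hue falls near the selected band (plus the overlap feather) are affected.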
[20]}, {"id": 1, "type": "GLSLShader", "pos": [4020, -2670], "size": [430, 212], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 9}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 10}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 11}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [3]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, 
v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", "from_input"]}], "groups": [], "links": [{"id": 10, "origin_id": 10, "origin_slot": 0, "target_id": 1, "target_slot": 2, "type": "FLOAT"}, {"id": 11, "origin_id": 12, "origin_slot": 1, "target_id": 1, "target_slot": 4, "type": "INT"}, {"id": 9, "origin_id": -10, "origin_slot": 0, "target_id": 1, "target_slot": 0, "type": "IMAGE"}, {"id": 3, "origin_id": 1, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Blur"}]}} diff --git a/blueprints/Image Captioning (gemini).json b/blueprints/Image Captioning (gemini).json new file mode 100644 index 000000000..89ebac802 --- /dev/null +++ b/blueprints/Image Captioning (gemini).json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 231, "last_link_id": 0, "nodes": [{"id": 231, "type": "e3e78497-720e-45a2-b4fb-c7bfdb80dd11", "pos": [23.13283014087665, 1034.468391137315], "size": [280, 260], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": null}, {"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": []}], "properties": {"proxyWidgets": [["-1", "prompt"], ["-1", "model"], ["1", "seed"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": ["Describe this image", "gemini-2.5-pro"], "title": "Image Captioning(Gemini)"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "e3e78497-720e-45a2-b4fb-c7bfdb80dd11", "version": 1, "state": {"lastGroupId": 1, "lastNodeId": 16, "lastLinkId": 16, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Captioning(Gemini)", "inputNode": {"id": -10, "bounding": [-6870, 2530, 120, 100]}, "outputNode": {"id": -20, "bounding": [-6240, 2530, 120, 60]}, "inputs": [{"id": "97cb8fa5-0514-4e05-b206-46fa6d7b5589", "name": "images", "type": "IMAGE", "linkIds": [1], "localized_name": "images", "shape": 7, "pos": [-6770, 2550]}, {"id": "d8cbd7eb-636a-4d7b-8ff6-b22f1755e26c", "name": "prompt", "type": "STRING", "linkIds": [15], "pos": [-6770, 2570]}, {"id": "b034e26a-d114-4604-aec2-32783e86aa6b", "name": "model", "type": "COMBO", "linkIds": [16], "pos": [-6770, 2590]}], "outputs": [{"id": "e12c6e80-5210-4328-a581-bc8924c53070", "name": "STRING", "type": "STRING", "linkIds": [6], "localized_name": "STRING", "pos": [-6220, 2550]}], "widgets": [], "nodes": [{"id": 1, "type": "GeminiNode", "pos": [-6690, 2360], "size": [390, 430], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 1}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, 
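The blur shader above uses `#pragma passes 2` to run twice, with u_pass selecting a horizontal then a vertical 1-D convolution; that separable formulation turns an O(r²) 2-D Gaussian into two O(r) passes (the radial branch instead runs in pass 0 only, rotating a sample direction with an incremental cos/sin recurrence). A rough pure-Python equivalent of the separable Gaussian path, assuming clamp-to-edge sampling at the borders:

```python
import math

def gaussian_weights(radius: float):
    # Same sigma choice as the shader (radius / 2), pre-normalized so the
    # shader's final divide-by-totalWeight is already folded in.
    n = math.ceil(radius)
    sigma = radius / 2.0 or 1.0
    raw = [math.exp(-(i * i) / (2.0 * sigma * sigma)) for i in range(-n, n + 1)]
    total = sum(raw)
    return [w / total for w in raw], n

def blur_1d(img, radius: float, horizontal: bool):
    weights, n = gaussian_weights(radius)
    h, w = len(img), len(img[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            acc = 0.0
            for k, wk in zip(range(-n, n + 1), weights):
                sx = min(max(x + (k if horizontal else 0), 0), w - 1)  # clamp to edge
                sy = min(max(y + (0 if horizontal else k), 0), h - 1)
                acc += img[sy][sx] * wk
            out[y][x] = acc
    return out

def separable_gaussian(img, radius: float):
    # Pass 0: horizontal, pass 1: vertical -- the shader's u_pass switch.
    return blur_1d(blur_1d(img, radius, True), radius, False)

impulse = [[0.0] * 9 for _ in range(9)]
impulse[4][4] = 1.0
print(round(separable_gaussian(impulse, 2.0)[4][4], 4))  # peak spreads out
```

The box branch is identical with uniform weights; only the radial branch cannot be separated this way, which is why the shader short-circuits it to the first pass.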
{"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 15}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": 16}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [6]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["Describe this image", "gemini-2.5-pro", 511865409297955, "randomize", "- Role: AI Image Analysis and Description Specialist\n- Background: The user requires a prompt that enables AI to analyze images and generate detailed descriptions which can be used as drawing prompts to create similar images. This is essential for tasks like content creation, design inspiration, and artistic exploration.\n- Profile: As an AI Image Analysis and Description Specialist, you possess extensive knowledge in computer vision, image processing, and natural language generation. You are adept at interpreting visual data and translating it into descriptive text that can guide the creation of new images.\n- Skills: Proficiency in image recognition, feature extraction, descriptive language generation, and understanding of artistic elements such as composition, color, and texture.\n- Goals: To analyze the provided image, generate a comprehensive and detailed description that captures the key visual elements, and ensure this description can effectively serve as a drawing prompt for creating similar images.\n- Constrains: The description must be clear, concise, and specific enough to guide the creation of a similar image. It should avoid ambiguity and focus on the most salient features of the image. The output should only contain the drawing prompt.\n- OutputFormat: A detailed text description of the image, highlighting key visual elements such as objects, colors, composition, and any unique features.\n- Workflow:\n 1. Analyze the image to identify key visual elements including objects, colors, and composition.\n 2. Generate a detailed description that captures the essence of the image, ensuring it is specific and actionable.\n 3. 
Refine the description to ensure clarity and conciseness, making it suitable for use as a drawing prompt."], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 1, "origin_id": -10, "origin_slot": 0, "target_id": 1, "target_slot": 0, "type": "IMAGE"}, {"id": 6, "origin_id": 1, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "*"}, {"id": 15, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 4, "type": "STRING"}, {"id": 16, "origin_id": -10, "origin_slot": 2, "target_id": 1, "target_slot": 5, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Image Captioning"}]}} diff --git a/blueprints/Image Channels.json b/blueprints/Image Channels.json new file mode 100644 index 000000000..cb3488883 --- /dev/null +++ b/blueprints/Image Channels.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 29, "last_link_id": 0, "nodes": [{"id": 29, "type": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "pos": [1970, -230], "size": [180, 86], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": []}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": []}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": []}], "title": "Image Channels", "properties": {"proxyWidgets": []}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "4c9d6ea4-b912-40e5-8766-6793a9758c53", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 28, "lastLinkId": 39, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Channels", "inputNode": {"id": -10, "bounding": [1820, -185, 120, 60]}, "outputNode": {"id": -20, "bounding": [2460, -215, 120, 120]}, "inputs": [{"id": "3522932b-2d86-4a1f-a02a-cb29f3a9d7fe", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image", "pos": [1920, -165]}], "outputs": [{"id": "605cb9c3-b065-4d9b-81d2-3ec331889b2b", "name": "IMAGE0", "type": "IMAGE", "linkIds": [26], "localized_name": "IMAGE0", "label": "R", "pos": [2480, -195]}, {"id": "fb44a77e-0522-43e9-9527-82e7465b3596", "name": "IMAGE1", "type": "IMAGE", "linkIds": [27], "localized_name": "IMAGE1", "label": "G", "pos": [2480, -175]}, {"id": "81460ee6-0131-402a-874f-6bf3001fc4ff", "name": "IMAGE2", "type": "IMAGE", "linkIds": [28], "localized_name": "IMAGE2", "label": "B", "pos": [2480, -155]}, {"id": "ae690246-80d4-4951-b1d9-9306d8a77417", "name": "IMAGE3", "type": "IMAGE", "linkIds": [29], "localized_name": "IMAGE3", "label": "A", "pos": [2480, -135]}], "widgets": [], "nodes": [{"id": 23, "type": "GLSLShader", "pos": [2000, -330], "size": [400, 172], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}], "outputs": [{"label": 
"R", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [26]}, {"label": "G", "localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": [27]}, {"label": "B", "localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": [28]}, {"label": "A", "localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": [29]}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\nlayout(location = 1) out vec4 fragColor1;\nlayout(location = 2) out vec4 fragColor2;\nlayout(location = 3) out vec4 fragColor3;\n\nvoid main() {\n vec4 color = texture(u_image0, v_texCoord);\n // Output each channel as grayscale to separate render targets\n fragColor0 = vec4(vec3(color.r), 1.0); // Red channel\n fragColor1 = vec4(vec3(color.g), 1.0); // Green channel\n fragColor2 = vec4(vec3(color.b), 1.0); // Blue channel\n fragColor3 = vec4(vec3(color.a), 1.0); // Alpha channel\n}\n", "from_input"]}], "groups": [], "links": [{"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 27, "origin_id": 23, "origin_slot": 1, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 28, "origin_id": 23, "origin_slot": 2, "target_id": -20, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 23, "origin_slot": 3, "target_id": -20, "target_slot": 3, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}} diff --git a/blueprints/Image Edit (Flux.2 Klein 4B).json b/blueprints/Image Edit (Flux.2 Klein 4B).json new file mode 100644 index 000000000..c87c7e122 --- /dev/null +++ b/blueprints/Image Edit (Flux.2 Klein 4B).json @@ -0,0 +1 @@ +{"id": "6686cb78-8003-4289-b969-929755e9a84d", "revision": 0, "last_node_id": 81, "last_link_id": 179, "nodes": [{"id": 75, "type": "7b34ab90-36f9-45ba-a665-71d418f0df18", "pos": [311.66672468419983, 830], "size": [400, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "image", "type": "IMAGE", "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["73", "noise_seed"], ["73", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.8.2", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", null, null, "flux-2-klein-base-4b-fp8.safetensors", "qwen_3_4b.safetensors", "flux2-vae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "7b34ab90-36f9-45ba-a665-71d418f0df18", "version": 1, "state": {"lastGroupId": 4, "lastNodeId": 81, "lastLinkId": 179, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image Edit (Flux.2 Klein 4B)", "inputNode": {"id": -10, "bounding": [-576.3333463986639, 559.0277780034634, 120, 
140]}, "outputNode": {"id": -20, "bounding": [1373.6666536013363, 549.0277780034634, 120, 60]}, "inputs": [{"id": "7061147a-fb75-450d-8e97-c8be594a8e16", "name": "text", "type": "STRING", "linkIds": [162], "label": "prompt", "pos": [-476.33334639866393, 579.0277780034634]}, {"id": "68629112-b7b0-41ce-8912-23adad00d3db", "name": "image", "type": "IMAGE", "linkIds": [175], "pos": [-476.33334639866393, 599.0277780034634]}, {"id": "006f0b42-cb11-4484-8b7e-c34a9fb12824", "name": "unet_name", "type": "COMBO", "linkIds": [177], "pos": [-476.33334639866393, 619.0277780034634]}, {"id": "0083499c-8e83-4974-a587-ba6e89e36acc", "name": "clip_name", "type": "COMBO", "linkIds": [178], "pos": [-476.33334639866393, 639.0277780034634]}, {"id": "7c95e27c-7920-43d5-a0ac-c6570653f5da", "name": "vae_name", "type": "COMBO", "linkIds": [179], "pos": [-476.33334639866393, 659.0277780034634]}], "outputs": [{"id": "c5e7966d-07ed-4c9a-ad89-9d378a41ea7b", "name": "IMAGE", "type": "IMAGE", "linkIds": [153], "localized_name": "IMAGE", "pos": [1393.6666536013363, 569.0277780034634]}], "widgets": [], "nodes": [{"id": 61, "type": "KSamplerSelect", "pos": [560, 460], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [144]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["euler"]}, {"id": 62, "type": "Flux2Scheduler", "pos": [560, 560], "size": [270, 106], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 171}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 173}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [145]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "Flux2Scheduler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [20, 1024, 1024]}, {"id": 63, "type": "CFGGuider", "pos": [560, 320], "size": [270, 98], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 139}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 167}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 168}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [143]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [5]}, {"id": 65, "type": "VAEDecode", "pos": [1093.6666007601261, 154.02777277882814], "size": [220, 46], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "samples", 
"name": "samples", "type": "LATENT", "link": 147}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 148}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [153]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 70, "type": "UNETLoader", "pos": [-386.3333318901398, 203.8611174586574], "size": [364.42708333333337, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 177}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [139]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "UNETLoader", "models": [{"name": "flux-2-klein-base-4b-fp8.safetensors", "url": "https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4b-fp8/resolve/main/flux-2-klein-base-4b-fp8.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["flux-2-klein-base-4b-fp8.safetensors", "default"]}, {"id": 71, "type": "CLIPLoader", "pos": [-386.3333318901398, 353.8611341117752], "size": [364.42708333333337, 106], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 178}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [151, 152]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_3_4b.safetensors", "flux2", "default"]}, {"id": 74, "type": "CLIPTextEncode", "pos": [43.666666014853874, 204.02777159555063], "size": [430, 230], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 151}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 162}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [165]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 67, "type": "CLIPTextEncode", "pos": [43.666666014853874, 534.0277718670993], "size": [430, 
88], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 152}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [166]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 72, "type": "VAELoader", "pos": [-386.3333318901398, 523.8611624133522], "size": [364.42708333333337, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 179}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [148, 176]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "VAELoader", "models": [{"name": "flux2-vae.safetensors", "url": "https://huggingface.co/Comfy-Org/flux2-dev/resolve/main/split_files/vae/flux2-vae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["flux2-vae.safetensors"]}, {"id": 66, "type": "EmptyFlux2LatentImage", "pos": [570, 740], "size": [270, 106], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 172}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 174}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [146]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "EmptyFlux2LatentImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1024, 1024, 1]}, {"id": 80, "type": "ImageScaleToTotalPixels", "pos": [-391.6666683297289, 715.194415255584], "size": [270, 106], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 175}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "megapixels", "name": "megapixels", "type": "FLOAT", "widget": {"name": "megapixels"}, "link": null}, {"localized_name": "resolution_steps", "name": "resolution_steps", "type": "INT", "widget": {"name": "resolution_steps"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [169, 170]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "ImageScaleToTotalPixels", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["nearest-exact", 1, 1]}, {"id": 79, "type": "6007e698-2ebd-4917-84d8-299b35d7b7ab", "pos": [238.33332484215495, 835.1944447404384], "size": [240, 86], "flags": {}, 
"order": 12, "mode": 0, "inputs": [{"label": "positive", "name": "conditioning", "type": "CONDITIONING", "link": 165}, {"label": "negative", "name": "conditioning_1", "type": "CONDITIONING", "link": 166}, {"name": "pixels", "type": "IMAGE", "link": 169}, {"name": "vae", "type": "VAE", "link": 176}], "outputs": [{"label": "positive", "name": "CONDITIONING", "type": "CONDITIONING", "links": [167]}, {"label": "negative", "name": "CONDITIONING_1", "type": "CONDITIONING", "links": [168]}], "properties": {"proxyWidgets": [], "cnr_id": "comfy-core", "ver": "0.8.2", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 81, "type": "GetImageSize", "pos": [310, 720], "size": [187.5, 66], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 170}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [171, 172]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [173, 174]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": null}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 64, "type": "SamplerCustomAdvanced", "pos": [860, 220], "size": [212.3638671875, 106], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 142}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 143}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 144}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 145}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 146}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": [147]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 73, "type": "RandomNoise", "pos": [560, 200], "size": [270, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [142]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize"]}], "groups": [{"id": 1, "title": "Models", "bounding": [-390, 120, 380, 550], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Prompt", "bounding": [30, 120, 470, 550], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Sampler", "bounding": [540, 120, 532.3638671875, 550], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 139, "origin_id": 70, "origin_slot": 0, "target_id": 63, 
"target_slot": 0, "type": "MODEL"}, {"id": 142, "origin_id": 73, "origin_slot": 0, "target_id": 64, "target_slot": 0, "type": "NOISE"}, {"id": 143, "origin_id": 63, "origin_slot": 0, "target_id": 64, "target_slot": 1, "type": "GUIDER"}, {"id": 144, "origin_id": 61, "origin_slot": 0, "target_id": 64, "target_slot": 2, "type": "SAMPLER"}, {"id": 145, "origin_id": 62, "origin_slot": 0, "target_id": 64, "target_slot": 3, "type": "SIGMAS"}, {"id": 146, "origin_id": 66, "origin_slot": 0, "target_id": 64, "target_slot": 4, "type": "LATENT"}, {"id": 147, "origin_id": 64, "origin_slot": 0, "target_id": 65, "target_slot": 0, "type": "LATENT"}, {"id": 148, "origin_id": 72, "origin_slot": 0, "target_id": 65, "target_slot": 1, "type": "VAE"}, {"id": 152, "origin_id": 71, "origin_slot": 0, "target_id": 67, "target_slot": 0, "type": "CLIP"}, {"id": 151, "origin_id": 71, "origin_slot": 0, "target_id": 74, "target_slot": 0, "type": "CLIP"}, {"id": 153, "origin_id": 65, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 162, "origin_id": -10, "origin_slot": 0, "target_id": 74, "target_slot": 1, "type": "STRING"}, {"id": 165, "origin_id": 74, "origin_slot": 0, "target_id": 79, "target_slot": 0, "type": "CONDITIONING"}, {"id": 166, "origin_id": 67, "origin_slot": 0, "target_id": 79, "target_slot": 1, "type": "CONDITIONING"}, {"id": 167, "origin_id": 79, "origin_slot": 0, "target_id": 63, "target_slot": 1, "type": "CONDITIONING"}, {"id": 168, "origin_id": 79, "origin_slot": 1, "target_id": 63, "target_slot": 2, "type": "CONDITIONING"}, {"id": 169, "origin_id": 80, "origin_slot": 0, "target_id": 79, "target_slot": 2, "type": "IMAGE"}, {"id": 170, "origin_id": 80, "origin_slot": 0, "target_id": 81, "target_slot": 0, "type": "IMAGE"}, {"id": 171, "origin_id": 81, "origin_slot": 0, "target_id": 62, "target_slot": 1, "type": "INT"}, {"id": 172, "origin_id": 81, "origin_slot": 0, "target_id": 66, "target_slot": 0, "type": "INT"}, {"id": 173, "origin_id": 81, "origin_slot": 1, "target_id": 62, "target_slot": 2, "type": "INT"}, {"id": 174, "origin_id": 81, "origin_slot": 1, "target_id": 66, "target_slot": 1, "type": "INT"}, {"id": 175, "origin_id": -10, "origin_slot": 1, "target_id": 80, "target_slot": 0, "type": "IMAGE"}, {"id": 176, "origin_id": 72, "origin_slot": 0, "target_id": 79, "target_slot": 3, "type": "VAE"}, {"id": 177, "origin_id": -10, "origin_slot": 2, "target_id": 70, "target_slot": 0, "type": "COMBO"}, {"id": 178, "origin_id": -10, "origin_slot": 3, "target_id": 71, "target_slot": 0, "type": "COMBO"}, {"id": 179, "origin_id": -10, "origin_slot": 4, "target_id": 72, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Edit image"}, {"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab", "version": 1, "state": {"lastGroupId": 4, "lastNodeId": 81, "lastLinkId": 179, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Reference Conditioning", "inputNode": {"id": -10, "bounding": [-270, 990, 120, 120]}, "outputNode": {"id": -20, "bounding": [580, 970, 120, 80]}, "inputs": [{"id": "5c9a0f5e-8cee-4947-90bc-330de782043a", "name": "conditioning", "type": "CONDITIONING", "linkIds": [165], "label": "positive", "pos": [-170, 1010]}, {"id": "61826d46-4c21-4ad6-801c-3e3fa94115e2", "name": "conditioning_1", "type": "CONDITIONING", "linkIds": [166], "label": "negative", "pos": [-170, 1030]}, {"id": "345bf085-5939-47ff-9767-8f8f239a719c", "name": "pixels", "type": "IMAGE", "linkIds": [167], "pos": [-170, 1050]}, {"id": 
"f4594e34-e2f5-4f1e-b1fa-a1dc2aeb0a90", "name": "vae", "type": "VAE", "linkIds": [168], "pos": [-170, 1070]}], "outputs": [{"id": "b3357c0e-6428-4055-9cd3-3595f0896fa8", "name": "CONDITIONING", "type": "CONDITIONING", "linkIds": [169], "label": "positive", "pos": [600, 990]}, {"id": "01519713-2ed1-4694-a387-79f44e088e89", "name": "CONDITIONING_1", "type": "CONDITIONING", "linkIds": [170], "label": "negative", "pos": [600, 1010]}], "widgets": [], "nodes": [{"id": 76, "type": "ReferenceLatent", "pos": [170, 1050], "size": [204.134765625, 46], "flags": {"collapsed": false}, "order": 0, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 166}, {"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 163}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [170]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "ReferenceLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 78, "type": "VAEEncode", "pos": [-90, 1150], "size": [190, 46], "flags": {"collapsed": false}, "order": 2, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 167}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 168}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [163, 164]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "VAEEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 77, "type": "ReferenceLatent", "pos": [170, 940], "size": [210, 46], "flags": {"collapsed": false}, "order": 1, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 165}, {"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 164}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [169]}], "properties": {"cnr_id": "comfy-core", "ver": "0.8.2", "Node name for S&R": "ReferenceLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}], "groups": [], "links": [{"id": 163, "origin_id": 78, "origin_slot": 0, "target_id": 76, "target_slot": 1, "type": "LATENT"}, {"id": 164, "origin_id": 78, "origin_slot": 0, "target_id": 77, "target_slot": 1, "type": "LATENT"}, {"id": 165, "origin_id": -10, "origin_slot": 0, "target_id": 77, "target_slot": 0, "type": "CONDITIONING"}, {"id": 166, "origin_id": -10, "origin_slot": 1, "target_id": 76, "target_slot": 0, "type": "CONDITIONING"}, {"id": 167, "origin_id": -10, "origin_slot": 2, "target_id": 78, "target_slot": 0, "type": "IMAGE"}, {"id": 168, "origin_id": -10, "origin_slot": 3, "target_id": 78, "target_slot": 1, "type": "VAE"}, {"id": 169, "origin_id": 77, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "CONDITIONING"}, {"id": 170, "origin_id": 76, "origin_slot": 0, "target_id": -20, "target_slot": 1, "type": "CONDITIONING"}], "extra": {"workflowRendererVersion": "LG"}}]}, "config": {}, "extra": {"workflowRendererVersion": "LG", "ds": {"scale": 1.1478862047043865, 
"offset": [302.91933883258804, -648.9802050882657]}}, "version": 0.4} diff --git a/blueprints/Image Edit (Qwen 2511).json b/blueprints/Image Edit (Qwen 2511).json new file mode 100644 index 000000000..33e85333b --- /dev/null +++ b/blueprints/Image Edit (Qwen 2511).json @@ -0,0 +1 @@ +{"id": "d84b7d1a-a73f-4e31-bd16-983ac0cf5f1b", "revision": 0, "last_node_id": 17, "last_link_id": 32, "nodes": [{"id": 17, "type": "9fa6af8b-8c99-4446-8681-bccf8ba4ea54", "pos": [183.33334355513557, -120.00000702649223], "size": [383.0729166666667, 381.10677083333337], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image 1", "name": "image1", "type": "IMAGE", "link": null}, {"label": "image 2 (optional)", "name": "image2", "type": "IMAGE", "link": null}, {"label": "image 3 (optional)", "name": "image3", "type": "IMAGE", "link": null}, {"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": null}], "properties": {"proxyWidgets": [["-1", "prompt"], ["15", "seed"], ["15", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", null, null, "qwen_image_edit_2511_bf16.safetensors", "qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image_vae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "9fa6af8b-8c99-4446-8681-bccf8ba4ea54", "version": 1, "state": {"lastGroupId": 2, "lastNodeId": 17, "lastLinkId": 32, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image Edit (Qwen 2511)", "inputNode": {"id": -10, "bounding": [-412.6162343565087, 327.2321295314722, 142.59765625, 180]}, "outputNode": {"id": -20, "bounding": [1631.0466138212807, 305.6854343585077, 120, 60]}, "inputs": [{"id": "6e401a3f-21a6-4552-8ee4-179c313c1910", "name": "image1", "type": "IMAGE", "linkIds": [25], "label": "image 1", "pos": [-290.0185781065087, 347.2321295314722]}, {"id": "a0a6307b-62b8-481e-bb17-d332eceadbe4", "name": "image2", "type": "IMAGE", "linkIds": [21, 26], "label": "image 2 (optional)", "pos": [-290.0185781065087, 367.2321295314722]}, {"id": "232fe944-fc3f-43dd-bb34-112d0360cb5f", "name": "image3", "type": "IMAGE", "linkIds": [22, 27], "label": "image 3 (optional)", "pos": [-290.0185781065087, 387.2321295314722]}, {"id": "9b8ed2f4-5875-4f59-b4c1-5ab79a412f4e", "name": "prompt", "type": "STRING", "linkIds": [23], "pos": [-290.0185781065087, 407.2321295314722]}, {"id": "403a6bd0-f170-4cfb-b72e-cd7fa1dbcd06", "name": "unet_name", "type": "COMBO", "linkIds": [30], "pos": [-290.0185781065087, 427.2321295314722]}, {"id": "86a53531-2fab-47da-9525-858c80737044", "name": "clip_name", "type": "COMBO", "linkIds": [31], "pos": [-290.0185781065087, 447.2321295314722]}, {"id": "499f39e9-d698-41dc-b126-b7ea6024cf5d", "name": "vae_name", "type": "COMBO", "linkIds": [32], "pos": [-290.0185781065087, 467.2321295314722]}], "outputs": [{"id": "f2ccd1fa-428e-4127-89a6-760906013172", "name": "IMAGE", "type": "IMAGE", "linkIds": [24], "pos": [1651.0466138212807, 325.6854343585077]}], "widgets": [], "nodes": [{"id": 2, "type": "ModelSamplingAuraFlow", "pos": [791.0465113899395, -54.3145423152618], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": 
[{"localized_name": "model", "name": "model", "type": "MODEL", "link": 29}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [4]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3.1]}, {"id": 3, "type": "VAELoader", "pos": [-174.9530552190643, 462.6706561999898], "size": [396.1328125, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 32}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [6, 10, 12, 15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "VAELoader", "models": [{"name": "qwen_image_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/vae/qwen_image_vae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_image_vae.safetensors"]}, {"id": 4, "type": "UNETLoader", "pos": [-174.9530552190643, -23.329297689188216], "size": [396.1328125, 82], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 30}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [29]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "UNETLoader", "models": [{"name": "qwen_image_edit_2511_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image-Edit_ComfyUI/resolve/main/split_files/diffusion_models/qwen_image_edit_2511_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_image_edit_2511_bf16.safetensors", "default"]}, {"id": 5, "type": "FluxKontextMultiReferenceLatentMethod", "pos": [781.0466382725523, 315.68545764091465], "size": [309.66145833333337, 58], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 2}, {"localized_name": "reference_latents_method", "name": "reference_latents_method", "type": "COMBO", "widget": {"name": "reference_latents_method"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [18]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "FluxKontextMultiReferenceLatentMethod", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["index_timestep_zero"], "color": "#222", "bgcolor": "#000"}, {"id": 6, "type": "FluxKontextMultiReferenceLatentMethod", "pos": [781.0466382725523, 185.68543791920104], "size": [309.66145833333337, 58], "flags": {}, 
"order": 7, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 3}, {"localized_name": "reference_latents_method", "name": "reference_latents_method", "type": "COMBO", "widget": {"name": "reference_latents_method"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [17]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "FluxKontextMultiReferenceLatentMethod", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["index_timestep_zero"], "color": "#222", "bgcolor": "#000"}, {"id": 7, "type": "CFGNorm", "pos": [791.0465113899395, 55.68545297239743], "size": [270, 58], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 4}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "patched_model", "name": "patched_model", "type": "MODEL", "links": [16]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "CFGNorm", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 8, "type": "MarkdownNote", "pos": [1111.0466241355298, 555.6854726502876], "size": [270, 195.10416666666669], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [], "title": "KSampler settings", "properties": {}, "widgets_values": ["You can test and find the best setting by yourself. The following table is for reference.\n| | Qwen | Comfy | lightning LoRA |\n|--------|---------|------------|---------------------------|\n| Steps | 40 | 20 | 4 |\n| CFG | 4.0 | 4.0 | 1.0 |\n\nBy default, we use 20 steps as we just don't want it to take too long. 
Try 40 if you want a better result, but it will take longer."], "color": "#222", "bgcolor": "#000"}, {"id": 9, "type": "TextEncodeQwenImageEditPlus", "pos": [301.0466082538065, 305.6854454238875], "size": [420, 170], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 5}, {"localized_name": "vae", "name": "vae", "shape": 7, "type": "VAE", "link": 6}, {"localized_name": "image1", "name": "image1", "shape": 7, "type": "IMAGE", "link": 28}, {"localized_name": "image2", "name": "image2", "shape": 7, "type": "IMAGE", "link": 21}, {"localized_name": "image3", "name": "image3", "shape": 7, "type": "IMAGE", "link": 22}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [2]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "TextEncodeQwenImageEditPlus", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 10, "type": "Note", "pos": [801.0465236069665, 435.6854651456011], "size": [280, 88], "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [], "properties": {}, "widgets_values": ["The \"Edit Model Reference Method\" nodes above are not needed if you use Comfy files, but may be needed if you use repackaged ones from other people."], "color": "#432", "bgcolor": "#653"}, {"id": 13, "type": "TextEncodeQwenImageEditPlus", "pos": [301.0466082538065, -14.314562996972978], "size": [426.6276041666667, 215.55989583333334], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 11}, {"localized_name": "vae", "name": "vae", "shape": 7, "type": "VAE", "link": 12}, {"localized_name": "image1", "name": "image1", "shape": 7, "type": "IMAGE", "link": 13}, {"localized_name": "image2", "name": "image2", "shape": 7, "type": "IMAGE", "link": 26}, {"localized_name": "image3", "name": "image3", "shape": 7, "type": "IMAGE", "link": 27}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 23}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [3]}], "title": "TextEncodeQwenImageEditPlus (Positive)", "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "TextEncodeQwenImageEditPlus", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 14, "type": "VAEEncode", "pos": [511.0465866120977, 645.6854435038923], "size": [187.5, 46], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 14}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 15}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [19]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "VAEEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 15, "type": "KSampler", "pos": [1101.0466119185025, 
-54.3145423152618], "size": [280, 510], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 16}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 17}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 18}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 19}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [9]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize", 40, 4, "euler", "simple", 1]}, {"id": 12, "type": "VAEDecode", "pos": [1431.0464586818402, -44.31456487314459], "size": [187.5, 46], "flags": {"collapsed": false}, "order": 10, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 9}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 10}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [24]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 16, "type": "FluxKontextImageScale", "pos": [-170, 630], "size": [194.9458984375, 26], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 25}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [7, 13, 14, 28]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "FluxKontextImageScale", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 1, "type": "CLIPLoader", "pos": [-170, 200], "size": [396.1328125, 106], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 31}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [5, 11]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "url": 
"https://huggingface.co/Comfy-Org/HunyuanVideo_1.5_repackaged/resolve/main/split_files/text_encoders/qwen_2.5_vl_7b_fp8_scaled.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image", "default"]}], "groups": [{"id": 1, "title": "Models", "bounding": [-180, -90, 416.1419982910156, 630.0299011230469], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Prompt", "bounding": [250, -90, 510, 630], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 2, "origin_id": 9, "origin_slot": 0, "target_id": 5, "target_slot": 0, "type": "CONDITIONING"}, {"id": 3, "origin_id": 13, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "CONDITIONING"}, {"id": 4, "origin_id": 2, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "MODEL"}, {"id": 5, "origin_id": 1, "origin_slot": 0, "target_id": 9, "target_slot": 0, "type": "CLIP"}, {"id": 6, "origin_id": 3, "origin_slot": 0, "target_id": 9, "target_slot": 1, "type": "VAE"}, {"id": 9, "origin_id": 15, "origin_slot": 0, "target_id": 12, "target_slot": 0, "type": "LATENT"}, {"id": 10, "origin_id": 3, "origin_slot": 0, "target_id": 12, "target_slot": 1, "type": "VAE"}, {"id": 11, "origin_id": 1, "origin_slot": 0, "target_id": 13, "target_slot": 0, "type": "CLIP"}, {"id": 12, "origin_id": 3, "origin_slot": 0, "target_id": 13, "target_slot": 1, "type": "VAE"}, {"id": 13, "origin_id": 16, "origin_slot": 0, "target_id": 13, "target_slot": 2, "type": "IMAGE"}, {"id": 14, "origin_id": 16, "origin_slot": 0, "target_id": 14, "target_slot": 0, "type": "IMAGE"}, {"id": 15, "origin_id": 3, "origin_slot": 0, "target_id": 14, "target_slot": 1, "type": "VAE"}, {"id": 16, "origin_id": 7, "origin_slot": 0, "target_id": 15, "target_slot": 0, "type": "MODEL"}, {"id": 17, "origin_id": 6, "origin_slot": 0, "target_id": 15, "target_slot": 1, "type": "CONDITIONING"}, {"id": 18, "origin_id": 5, "origin_slot": 0, "target_id": 15, "target_slot": 2, "type": "CONDITIONING"}, {"id": 19, "origin_id": 14, "origin_slot": 0, "target_id": 15, "target_slot": 3, "type": "LATENT"}, {"id": 21, "origin_id": -10, "origin_slot": 1, "target_id": 9, "target_slot": 3, "type": "IMAGE"}, {"id": 22, "origin_id": -10, "origin_slot": 2, "target_id": 9, "target_slot": 4, "type": "IMAGE"}, {"id": 23, "origin_id": -10, "origin_slot": 3, "target_id": 13, "target_slot": 5, "type": "STRING"}, {"id": 24, "origin_id": 12, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 25, "origin_id": -10, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "IMAGE"}, {"id": 26, "origin_id": -10, "origin_slot": 1, "target_id": 13, "target_slot": 3, "type": "IMAGE"}, {"id": 27, "origin_id": -10, "origin_slot": 2, "target_id": 13, "target_slot": 4, "type": "IMAGE"}, {"id": 28, "origin_id": 16, "origin_slot": 0, "target_id": 9, "target_slot": 2, "type": "IMAGE"}, {"id": 29, "origin_id": 4, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "MODEL"}, {"id": 30, "origin_id": -10, "origin_slot": 4, "target_id": 4, "target_slot": 0, "type": "COMBO"}, {"id": 31, "origin_id": -10, "origin_slot": 5, "target_id": 1, "target_slot": 0, "type": "COMBO"}, {"id": 32, "origin_id": -10, "origin_slot": 6, "target_id": 3, "target_slot": 0, "type": "COMBO"}], "extra": {"frontendVersion": "1.37.11", "workflowRendererVersion": "LG", 
"VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "category": "Image generation and editing/Edit image"}]}, "config": {}, "extra": {"frontendVersion": "1.37.11", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true, "ds": {"scale": 0.8597138248970195, "offset": [716.4750075519744, 479.19752576099086]}}, "version": 0.4} diff --git a/blueprints/Image Inpainting (Qwen-image).json b/blueprints/Image Inpainting (Qwen-image).json new file mode 100644 index 000000000..5f8ef81f9 --- /dev/null +++ b/blueprints/Image Inpainting (Qwen-image).json @@ -0,0 +1 @@ +{"id": "84318cde-a839-41d4-8632-df6d7c50ffc5", "revision": 0, "last_node_id": 256, "last_link_id": 403, "nodes": [{"id": 256, "type": "c93d5779-7bfe-4511-98e2-6a665ed0dff2", "pos": [2271.698367680439, -460.52399024524993], "size": [420, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": null}, {"localized_name": "mask", "name": "mask", "type": "MASK", "link": null}, {"name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"name": "control_net_name", "type": "COMBO", "widget": {"name": "control_net_name"}, "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", "links": null}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "control_net_name"], ["3", "seed"], ["3", "control_after_generate"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": ["", "qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image_vae.safetensors", "Qwen-Image-InstantX-ControlNet-Inpainting.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "c93d5779-7bfe-4511-98e2-6a665ed0dff2", "version": 1, "state": {"lastGroupId": 14, "lastNodeId": 256, "lastLinkId": 403, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image Inpainting (Qwen-image)", "inputNode": {"id": -10, "bounding": [-860, 530, 140.587890625, 160]}, "outputNode": {"id": -20, "bounding": [1290, 530, 120, 60]}, "inputs": [{"id": "61dc027a-a7fc-4c40-8aa4-fd4a6e36d00f", "name": "image", "type": "IMAGE", "linkIds": [399], "localized_name": "image", "pos": [-739.412109375, 550]}, {"id": "28f4cf42-1c6d-49b8-abce-53ef9c628907", "name": "mask", "type": "MASK", "linkIds": [205], "localized_name": "mask", "pos": [-739.412109375, 570]}, {"id": "f082f9ab-9a31-4d99-b4fd-4900453a30a8", "name": "text", "type": "STRING", "linkIds": [394], "pos": [-739.412109375, 590]}, {"id": "9e692477-812a-4054-b780-471228a9821c", "name": "clip_name", "type": "COMBO", "linkIds": [401], "pos": [-739.412109375, 610]}, {"id": "dfbf7eac-1f92-4636-9ead-6a1c2595c5e2", "name": "vae_name", "type": "COMBO", "linkIds": [402], "pos": [-739.412109375, 630]}, {"id": "cfaf4549-e61b-4a88-a514-24894142433a", "name": "control_net_name", "type": "COMBO", "linkIds": [403], "pos": [-739.412109375, 650]}], "outputs": [{"id": "45b4d67e-3d8f-4936-9599-607a23161a3c", "name": "IMAGE", "type": "IMAGE", "linkIds": [400], "pos": [1310, 550]}], "widgets": [], "nodes": [{"id": 38, "type": "CLIPLoader", "pos": [-90, 70], "size": [380, 106], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", 
"widget": {"name": "clip_name"}, "link": 401}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": [74, 75]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/text_encoders/qwen_2.5_vl_7b_fp8_scaled.safetensors", "directory": "text_encoders"}]}, "widgets_values": ["qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image", "default"]}, {"id": 37, "type": "UNETLoader", "pos": [-90, -60], "size": [380, 82], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [145]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "UNETLoader", "models": [{"name": "qwen_image_fp8_e4m3fn.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/diffusion_models/qwen_image_fp8_e4m3fn.safetensors", "directory": "diffusion_models"}]}, "widgets_values": ["qwen_image_fp8_e4m3fn.safetensors", "default"]}, {"id": 7, "type": "CLIPTextEncode", "pos": [330, 320], "size": [460, 140], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 75}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [191]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": [" "], "color": "#223", "bgcolor": "#335"}, {"id": 84, "type": "ControlNetLoader", "pos": [-90, 340], "size": [380, 58], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "control_net_name", "name": "control_net_name", "type": "COMBO", "widget": {"name": "control_net_name"}, "link": 403}], "outputs": [{"localized_name": "CONTROL_NET", "name": "CONTROL_NET", "type": "CONTROL_NET", "links": [192]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ControlNetLoader", "models": [{"name": "Qwen-Image-InstantX-ControlNet-Inpainting.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image-InstantX-ControlNets/resolve/main/split_files/controlnet/Qwen-Image-InstantX-ControlNet-Inpainting.safetensors", "directory": "controlnet"}]}, "widgets_values": ["Qwen-Image-InstantX-ControlNet-Inpainting.safetensors"]}, {"id": 39, "type": "VAELoader", "pos": [-90, 230], "size": [380, 58], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 402}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [76, 144, 193]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAELoader", 
"models": [{"name": "qwen_image_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/vae/qwen_image_vae.safetensors", "directory": "vae"}]}, "widgets_values": ["qwen_image_vae.safetensors"]}, {"id": 66, "type": "ModelSamplingAuraFlow", "pos": [860, -100], "size": [310, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 149}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [156]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ModelSamplingAuraFlow"}, "widgets_values": [3.1000000000000005]}, {"id": 108, "type": "ControlNetInpaintingAliMamaApply", "pos": [430, 560], "size": [317.0093688964844, 206], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 190}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 191}, {"localized_name": "control_net", "name": "control_net", "type": "CONTROL_NET", "link": 192}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 193}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 397}, {"localized_name": "mask", "name": "mask", "type": "MASK", "link": 220}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}, {"localized_name": "start_percent", "name": "start_percent", "type": "FLOAT", "widget": {"name": "start_percent"}, "link": null}, {"localized_name": "end_percent", "name": "end_percent", "type": "FLOAT", "widget": {"name": "end_percent"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [188]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [189]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ControlNetInpaintingAliMamaApply"}, "widgets_values": [1, 0, 1]}, {"id": 86, "type": "Note", "pos": [860, 500], "size": [307.4002380371094, 127.38092803955078], "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [], "properties": {}, "widgets_values": ["Set cfg to 1.0 for a speed boost at the cost of consistency. Samplers like res_multistep work pretty well at cfg 1.0\n\nThe official number of steps is 50 but I think that's too much. 
Even just 10 steps seems to work."], "color": "#432", "bgcolor": "#653"}, {"id": 76, "type": "VAEEncode", "pos": [430, 830], "size": [140, 46], "flags": {"collapsed": true}, "order": 11, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 396}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 144}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [208]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAEEncode"}, "widgets_values": []}, {"id": 122, "type": "SetLatentNoiseMask", "pos": [430, 890], "size": [230, 50], "flags": {"collapsed": true}, "order": 15, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 208}, {"localized_name": "mask", "name": "mask", "type": "MASK", "link": 219}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [210]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "SetLatentNoiseMask"}, "widgets_values": []}, {"id": 223, "type": "MarkdownNote", "pos": [860, 670], "size": [300, 160], "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [], "title": "Note: KSampler settings", "properties": {}, "widgets_values": ["You can test and find the best setting by yourself. The following table is for reference.\n| Parameters | Qwen Team | Comfy Original | with 4steps LoRA |\n|--------|---------|------------|---------------------------|\n| Steps | 50 | 20 | 4 |\n| CFG | 4.0 | 2.5 | 1.0 |"], "color": "#432", "bgcolor": "#653"}, {"id": 80, "type": "LoraLoaderModelOnly", "pos": [350, -70], "size": [430, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 145}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [149]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "Qwen-Image-Lightning-4steps-V1.0.safetensors", "url": "https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0.safetensors", "directory": "loras"}]}, "widgets_values": ["Qwen-Image-Lightning-4steps-V1.0.safetensors", 1]}, {"id": 6, "type": "CLIPTextEncode", "pos": [330, 110], "size": [460, 164.31304931640625], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 74}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 394}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [190]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 121, "type": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8", "pos": [430, 950], "size": [330, 100], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 205}, {"name": "expand", "type": "INT", "widget": {"name": "expand"}, "link": null}, {"name": "blur_radius", "type": "INT", "widget": {"name": 
"blur_radius"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [215, 219, 220]}], "properties": {"proxyWidgets": [["-1", "expand"], ["-1", "blur_radius"]], "cnr_id": "comfy-core", "ver": "0.3.59"}, "widgets_values": [0, 1]}, {"id": 3, "type": "KSampler", "pos": [860, 20], "size": [310, 430], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 156}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 188}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 189}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 210}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [128]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "KSampler"}, "widgets_values": [0, "randomize", 4, 1, "euler", "simple", 1]}, {"id": 224, "type": "FluxKontextImageScale", "pos": [10, 1090], "size": [194.9458984375, 26], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 399}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [396, 397]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "FluxKontextImageScale"}, "widgets_values": []}, {"id": 8, "type": "VAEDecode", "pos": [900, 880], "size": [250, 46], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 128}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 76}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [110, 400]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAEDecode"}, "widgets_values": []}, {"id": 124, "type": "MaskPreview", "pos": [440, 1100], "size": [320, 340], "flags": {}, "order": 16, "mode": 4, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 215}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "MaskPreview"}, "widgets_values": []}], "groups": [{"id": 1, "title": "Step 1 - Upload models", "bounding": [-100, -140, 400, 610], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 4, "title": "Step 3 - Prompt", "bounding": [320, 40, 490, 430], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 5, "title": "4 steps lightning LoRA", "bounding": [320, -140, 490, 160], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 14, "title": "Inpainting", "bounding": [-110, -180, 1340, 1650], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 75, "origin_id": 38, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": 
"CLIP"}, {"id": 149, "origin_id": 80, "origin_slot": 0, "target_id": 66, "target_slot": 0, "type": "MODEL"}, {"id": 190, "origin_id": 6, "origin_slot": 0, "target_id": 108, "target_slot": 0, "type": "CONDITIONING"}, {"id": 191, "origin_id": 7, "origin_slot": 0, "target_id": 108, "target_slot": 1, "type": "CONDITIONING"}, {"id": 192, "origin_id": 84, "origin_slot": 0, "target_id": 108, "target_slot": 2, "type": "CONTROL_NET"}, {"id": 193, "origin_id": 39, "origin_slot": 0, "target_id": 108, "target_slot": 3, "type": "VAE"}, {"id": 220, "origin_id": 121, "origin_slot": 0, "target_id": 108, "target_slot": 5, "type": "MASK"}, {"id": 144, "origin_id": 39, "origin_slot": 0, "target_id": 76, "target_slot": 1, "type": "VAE"}, {"id": 208, "origin_id": 76, "origin_slot": 0, "target_id": 122, "target_slot": 0, "type": "LATENT"}, {"id": 219, "origin_id": 121, "origin_slot": 0, "target_id": 122, "target_slot": 1, "type": "MASK"}, {"id": 215, "origin_id": 121, "origin_slot": 0, "target_id": 124, "target_slot": 0, "type": "MASK"}, {"id": 128, "origin_id": 3, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 76, "origin_id": 39, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 74, "origin_id": 38, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "CLIP"}, {"id": 145, "origin_id": 37, "origin_slot": 0, "target_id": 80, "target_slot": 0, "type": "MODEL"}, {"id": 156, "origin_id": 66, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 188, "origin_id": 108, "origin_slot": 0, "target_id": 3, "target_slot": 1, "type": "CONDITIONING"}, {"id": 189, "origin_id": 108, "origin_slot": 1, "target_id": 3, "target_slot": 2, "type": "CONDITIONING"}, {"id": 210, "origin_id": 122, "origin_slot": 0, "target_id": 3, "target_slot": 3, "type": "LATENT"}, {"id": 205, "origin_id": -10, "origin_slot": 1, "target_id": 121, "target_slot": 0, "type": "MASK"}, {"id": 110, "origin_id": 8, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 394, "origin_id": -10, "origin_slot": 2, "target_id": 6, "target_slot": 1, "type": "STRING"}, {"id": 396, "origin_id": 224, "origin_slot": 0, "target_id": 76, "target_slot": 0, "type": "IMAGE"}, {"id": 397, "origin_id": 224, "origin_slot": 0, "target_id": 108, "target_slot": 4, "type": "IMAGE"}, {"id": 399, "origin_id": -10, "origin_slot": 0, "target_id": 224, "target_slot": 0, "type": "IMAGE"}, {"id": 400, "origin_id": 8, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 401, "origin_id": -10, "origin_slot": 3, "target_id": 38, "target_slot": 0, "type": "COMBO"}, {"id": 402, "origin_id": -10, "origin_slot": 4, "target_id": 39, "target_slot": 0, "type": "COMBO"}, {"id": 403, "origin_id": -10, "origin_slot": 5, "target_id": 84, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Inpaint image"}, {"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8", "version": 1, "state": {"lastGroupId": 14, "lastNodeId": 256, "lastLinkId": 403, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Grow and Blur Mask", "inputNode": {"id": -10, "bounding": [290, 3536, 120, 100]}, "outputNode": {"id": -20, "bounding": [1130, 3536, 120, 60]}, "inputs": [{"id": "3ac60d5e-8f9d-4663-9b24-b3a15a3e9e20", "name": "mask", "type": "MASK", "linkIds": [279], "localized_name": "mask", "pos": [390, 3556]}, {"id": "d1ab0cf5-7062-41ac-9f4b-8c660fc4a714", "name": "expand", "type": "INT", "linkIds": [379], "pos": [390, 
3576]}, {"id": "1a787af5-da9f-44c5-9f5a-3f71609ca0ef", "name": "blur_radius", "type": "INT", "linkIds": [380], "pos": [390, 3596]}], "outputs": [{"id": "1f97f683-13d3-4871-876d-678fca850d89", "name": "MASK", "type": "MASK", "linkIds": [378], "localized_name": "MASK", "pos": [1150, 3556]}], "widgets": [], "nodes": [{"id": 253, "type": "ImageToMask", "pos": [800, 3630], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 377}, {"localized_name": "channel", "name": "channel", "type": "COMBO", "widget": {"name": "channel"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [378]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageToMask"}, "widgets_values": ["red"]}, {"id": 251, "type": "MaskToImage", "pos": [780, 3470], "size": [260, 70], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 372}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [373]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "MaskToImage"}, "widgets_values": []}, {"id": 199, "type": "GrowMask", "pos": [470, 3460], "size": [270, 82], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 279}, {"localized_name": "expand", "name": "expand", "type": "INT", "widget": {"name": "expand"}, "link": 379}, {"localized_name": "tapered_corners", "name": "tapered_corners", "type": "BOOLEAN", "widget": {"name": "tapered_corners"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [372]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "GrowMask"}, "widgets_values": [0, true]}, {"id": 252, "type": "ImageBlur", "pos": [480, 3620], "size": [270, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 373}, {"localized_name": "blur_radius", "name": "blur_radius", "type": "INT", "widget": {"name": "blur_radius"}, "link": 380}, {"localized_name": "sigma", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [377]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageBlur"}, "widgets_values": [1, 1]}], "groups": [], "links": [{"id": 373, "origin_id": 251, "origin_slot": 0, "target_id": 252, "target_slot": 0, "type": "IMAGE"}, {"id": 377, "origin_id": 252, "origin_slot": 0, "target_id": 253, "target_slot": 0, "type": "IMAGE"}, {"id": 372, "origin_id": 199, "origin_slot": 0, "target_id": 251, "target_slot": 0, "type": "MASK"}, {"id": 279, "origin_id": -10, "origin_slot": 0, "target_id": 199, "target_slot": 0, "type": "MASK"}, {"id": 378, "origin_id": 253, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "MASK"}, {"id": 379, "origin_id": -10, "origin_slot": 1, "target_id": 199, "target_slot": 1, "type": "INT"}, {"id": 380, "origin_id": -10, "origin_slot": 2, "target_id": 252, "target_slot": 1, "type": "INT"}], "extra": {"workflowRendererVersion": "LG"}}]}, "config": {}, "extra": {"ds": {"scale": 1.088930769230769, "offset": [-1576.5829757292656, 657.608356702113]}, "workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Image Levels.json b/blueprints/Image Levels.json new file mode 100644 
index 000000000..f028662bd --- /dev/null +++ b/blueprints/Image Levels.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 139, "last_link_id": 0, "nodes": [{"id": 139, "type": "75bf8a72-aad8-4f3e-83ee-380e70248240", "pos": [620, 900], "size": [240, 178], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["5", "choice"], ["3", "value"], ["6", "value"], ["7", "value"], ["8", "value"], ["9", "value"]]}, "widgets_values": [], "title": "Image Levels"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "75bf8a72-aad8-4f3e-83ee-380e70248240", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 144, "lastLinkId": 118, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image Levels", "inputNode": {"id": -10, "bounding": [3840, -3430, 120, 60]}, "outputNode": {"id": -20, "bounding": [4950, -3430, 120, 60]}, "inputs": [{"id": "b53e5012-fa47-400f-a324-28c74854ccae", "name": "images.image0", "type": "IMAGE", "linkIds": [1], "localized_name": "images.image0", "label": "image", "pos": [3940, -3410]}], "outputs": [{"id": "de7f2ffa-155f-41fd-b054-aa4d91ef49ca", "name": "IMAGE0", "type": "IMAGE", "linkIds": [8], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [4970, -3410]}], "widgets": [], "nodes": [{"id": 5, "type": "CustomCombo", "pos": [4020, -3350], "size": [270, 198], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "channel", "localized_name": "choice", "name": "choice", "type": "COMBO", "widget": {"name": "choice"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": null}, {"localized_name": "INDEX", "name": "INDEX", "type": "INT", "links": [3]}], "title": "Channel", "properties": {"Node name for S&R": "CustomCombo"}, "widgets_values": ["RGB", 0, "RGB", "R", "G", "B", ""]}, {"id": 8, "type": "PrimitiveFloat", "pos": [4020, -3550], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "output_black", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [6]}], "title": "Output Black", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 255, "min": 0, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [0]}, {"id": 3, "type": "PrimitiveFloat", "pos": [4020, -3850], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "input_black", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [2]}], "title": "Input Black", "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 255, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [0]}, {"id": 6, "type": "PrimitiveFloat", "pos": [4020, -3750], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "input_white", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": 
"FLOAT", "type": "FLOAT", "links": [4]}], "title": "Input White", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 255, "min": 0, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [255]}, {"id": 7, "type": "PrimitiveFloat", "pos": [4020, -3650], "size": [270, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "gamma", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [5]}], "title": "Gamma", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 10, "min": 0, "step": 0.01, "precision": 2, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 0.5, "color": [128, 128, 128]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [1]}, {"id": 9, "type": "PrimitiveFloat", "pos": [4020, -3450], "size": [270, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"label": "output_white", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [7]}], "title": "Output White", "properties": {"Node name for S&R": "PrimitiveFloat", "max": 255, "min": 0, "step": 1, "display": "gradientslider", "gradient_stops": [{"offset": 0, "color": [0, 0, 0]}, {"offset": 1, "color": [255, 255, 255]}]}, "widgets_values": [255]}, {"id": 1, "type": "GLSLShader", "pos": [4310, -3850], "size": [580, 272], "flags": {}, "order": 6, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 1}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 2}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 4}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 5}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": 6}, {"label": "u_float4", "localized_name": "floats.u_float4", "name": "floats.u_float4", "shape": 7, "type": "FLOAT", "link": 7}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": 3}, {"label": "u_int1", "localized_name": "ints.u_int1", "name": "ints.u_int1", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [8]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\n// Levels Adjustment\n// u_int0: channel 
(0=RGB, 1=R, 2=G, 3=B) default: 0\n// u_float0: input black (0-255) default: 0\n// u_float1: input white (0-255) default: 255\n// u_float2: gamma (0.01-9.99) default: 1.0\n// u_float3: output black (0-255) default: 0\n// u_float4: output white (0-255) default: 255\n\nuniform sampler2D u_image0;\nuniform int u_int0;\nuniform float u_float0;\nuniform float u_float1;\nuniform float u_float2;\nuniform float u_float3;\nuniform float u_float4;\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nvec3 applyLevels(vec3 color, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {\n float inRange = max(inWhite - inBlack, 0.0001);\n vec3 result = clamp((color - inBlack) / inRange, 0.0, 1.0);\n result = pow(result, vec3(1.0 / gamma));\n result = mix(vec3(outBlack), vec3(outWhite), result);\n return result;\n}\n\nfloat applySingleChannel(float value, float inBlack, float inWhite, float gamma, float outBlack, float outWhite) {\n float inRange = max(inWhite - inBlack, 0.0001);\n float result = clamp((value - inBlack) / inRange, 0.0, 1.0);\n result = pow(result, 1.0 / gamma);\n result = mix(outBlack, outWhite, result);\n return result;\n}\n\nvoid main() {\n vec4 texColor = texture(u_image0, v_texCoord);\n vec3 color = texColor.rgb;\n \n float inBlack = u_float0 / 255.0;\n float inWhite = u_float1 / 255.0;\n float gamma = u_float2;\n float outBlack = u_float3 / 255.0;\n float outWhite = u_float4 / 255.0;\n \n vec3 result;\n \n if (u_int0 == 0) {\n result = applyLevels(color, inBlack, inWhite, gamma, outBlack, outWhite);\n }\n else if (u_int0 == 1) {\n result = color;\n result.r = applySingleChannel(color.r, inBlack, inWhite, gamma, outBlack, outWhite);\n }\n else if (u_int0 == 2) {\n result = color;\n result.g = applySingleChannel(color.g, inBlack, inWhite, gamma, outBlack, outWhite);\n }\n else if (u_int0 == 3) {\n result = color;\n result.b = applySingleChannel(color.b, inBlack, inWhite, gamma, outBlack, outWhite);\n }\n else {\n result = color;\n }\n \n fragColor = vec4(result, texColor.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 2, "origin_id": 3, "origin_slot": 0, "target_id": 1, "target_slot": 2, "type": "FLOAT"}, {"id": 4, "origin_id": 6, "origin_slot": 0, "target_id": 1, "target_slot": 3, "type": "FLOAT"}, {"id": 5, "origin_id": 7, "origin_slot": 0, "target_id": 1, "target_slot": 4, "type": "FLOAT"}, {"id": 6, "origin_id": 8, "origin_slot": 0, "target_id": 1, "target_slot": 5, "type": "FLOAT"}, {"id": 7, "origin_id": 9, "origin_slot": 0, "target_id": 1, "target_slot": 6, "type": "FLOAT"}, {"id": 3, "origin_id": 5, "origin_slot": 1, "target_id": 1, "target_slot": 7, "type": "INT"}, {"id": 1, "origin_id": -10, "origin_slot": 0, "target_id": 1, "target_slot": 0, "type": "IMAGE"}, {"id": 8, "origin_id": 1, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Color adjust"}]}, "extra": {}} diff --git a/blueprints/Image Outpainting (Qwen-Image).json b/blueprints/Image Outpainting (Qwen-Image).json new file mode 100644 index 000000000..f36e0bd77 --- /dev/null +++ b/blueprints/Image Outpainting (Qwen-Image).json @@ -0,0 +1 @@ +{"id": "8f79c27f-bec4-412e-9b82-7c5b3b778ecf", "revision": 0, "last_node_id": 255, "last_link_id": 401, "nodes": [{"id": 224, "type": "fbf07656-8ff8-4299-a3fc-7378e0f4a004", "pos": [3200, 740], "size": [400, 460], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": null}, {"name": "left", "type": "INT", 
"widget": {"name": "left"}, "link": null}, {"name": "top", "type": "INT", "widget": {"name": "top"}, "link": null}, {"name": "right", "type": "INT", "widget": {"name": "right"}, "link": null}, {"name": "bottom", "type": "INT", "widget": {"name": "bottom"}, "link": null}, {"name": "feathering", "type": "INT", "widget": {"name": "feathering"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"name": "control_net_name", "type": "COMBO", "widget": {"name": "control_net_name"}, "link": null}, {"name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["182", "text"], ["-1", "left"], ["-1", "top"], ["-1", "right"], ["-1", "bottom"], ["-1", "feathering"], ["190", "seed"], ["190", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "control_net_name"], ["-1", "lora_name"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": [null, 0, 0, 0, 0, 0, null, null, "qwen_image_fp8_e4m3fn.safetensors", "qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image_vae.safetensors", "Qwen-Image-InstantX-ControlNet-Inpainting.safetensors", "Qwen-Image-Lightning-4steps-V1.0.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "fbf07656-8ff8-4299-a3fc-7378e0f4a004", "version": 1, "state": {"lastGroupId": 14, "lastNodeId": 255, "lastLinkId": 401, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image Outpainting (Qwen-Image)", "inputNode": {"id": -10, "bounding": [1940, 610, 140.587890625, 260]}, "outputNode": {"id": -20, "bounding": [4240, 765, 120, 60]}, "inputs": [{"id": "466b9998-797f-4c6f-92e9-39120712c1a9", "name": "image", "type": "IMAGE", "linkIds": [351], "localized_name": "image", "pos": [2060.587890625, 630]}, {"id": "c5befee8-d6c4-493e-8ae1-e09d46268d10", "name": "left", "type": "INT", "linkIds": [392], "pos": [2060.587890625, 650]}, {"id": "c0b028a1-fcc0-4a54-9bdf-fa9e76992c40", "name": "top", "type": "INT", "linkIds": [393], "pos": [2060.587890625, 670]}, {"id": "22e43278-694c-410f-9043-f88b8dfdca28", "name": "right", "type": "INT", "linkIds": [394], "pos": [2060.587890625, 690]}, {"id": "f19fec20-a43d-4562-a0f8-bd6955091c1b", "name": "bottom", "type": "INT", "linkIds": [395], "pos": [2060.587890625, 710]}, {"id": "ba832b36-2199-4e1e-a28d-5f2e8acc99a3", "name": "feathering", "type": "INT", "linkIds": [396], "pos": [2060.587890625, 730]}, {"id": "437d6324-2d3c-4c50-ac21-1ea9aab57f4e", "name": "unet_name", "type": "COMBO", "linkIds": [397], "pos": [2060.587890625, 750]}, {"id": "4d58dde7-4402-45d5-ade9-9c41e99e0757", "name": "clip_name", "type": "COMBO", "linkIds": [398], "pos": [2060.587890625, 770]}, {"id": "a7558cc4-d4c4-4b4a-b2a3-0d7229a8ff65", "name": "vae_name", "type": "COMBO", "linkIds": [399], "pos": [2060.587890625, 790]}, {"id": "7d8ffb86-2ff3-49fc-8e96-94d3e530f154", "name": "control_net_name", "type": "COMBO", "linkIds": [400], "pos": [2060.587890625, 810]}, {"id": "a81e0fa5-5984-47ae-bb4f-108a2b92d373", "name": "lora_name", "type": "COMBO", "linkIds": [401], "pos": [2060.587890625, 830]}], "outputs": [{"id": "506ced76-78be-4eb2-ae70-eaa708a4cb98", "name": "IMAGE", "type": "IMAGE", "linkIds": [314], "localized_name": "IMAGE", "pos": 
[4260, 785]}], "widgets": [], "nodes": [{"id": 174, "type": "CLIPLoader", "pos": [2430, 60], "size": [380, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 398}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": [296, 305]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/text_encoders/qwen_2.5_vl_7b_fp8_scaled.safetensors", "directory": "text_encoders"}]}, "widgets_values": ["qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image", "default"]}, {"id": 175, "type": "UNETLoader", "pos": [2430, -70], "size": [380, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 397}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [306]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "UNETLoader", "models": [{"name": "qwen_image_fp8_e4m3fn.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/diffusion_models/qwen_image_fp8_e4m3fn.safetensors", "directory": "diffusion_models"}]}, "widgets_values": ["qwen_image_fp8_e4m3fn.safetensors", "default"]}, {"id": 177, "type": "ControlNetLoader", "pos": [2430, 330], "size": [380, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "control_net_name", "name": "control_net_name", "type": "COMBO", "widget": {"name": "control_net_name"}, "link": 400}], "outputs": [{"localized_name": "CONTROL_NET", "name": "CONTROL_NET", "type": "CONTROL_NET", "links": [301]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ControlNetLoader", "models": [{"name": "Qwen-Image-InstantX-ControlNet-Inpainting.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image-InstantX-ControlNets/resolve/main/split_files/controlnet/Qwen-Image-InstantX-ControlNet-Inpainting.safetensors", "directory": "controlnet"}]}, "widgets_values": ["Qwen-Image-InstantX-ControlNet-Inpainting.safetensors"]}, {"id": 180, "type": "ModelSamplingAuraFlow", "pos": [3400, -110], "size": [310, 58], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 298}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [308]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ModelSamplingAuraFlow"}, "widgets_values": [3.1000000000000005]}, {"id": 185, "type": "LoraLoaderModelOnly", "pos": [2870, -80], "size": [430, 82], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 306}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": 
{"name": "lora_name"}, "link": 401}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [298]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "Qwen-Image-Lightning-4steps-V1.0.safetensors", "url": "https://huggingface.co/lightx2v/Qwen-Image-Lightning/resolve/main/Qwen-Image-Lightning-4steps-V1.0.safetensors", "directory": "loras"}]}, "widgets_values": ["Qwen-Image-Lightning-4steps-V1.0.safetensors", 1]}, {"id": 190, "type": "KSampler", "pos": [3400, 10], "size": [310, 474], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 308}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 386}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 387}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 358}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [312]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "KSampler"}, "widgets_values": [375729975350303, "randomize", 4, 1, "euler", "simple", 1]}, {"id": 220, "type": "f93c215e-c393-460e-9534-ed2c3d8a652e", "pos": [2480, 1450], "size": [330, 100], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 377}, {"name": "expand", "type": "INT", "widget": {"name": "expand"}, "link": null}, {"name": "blur_radius", "type": "INT", "widget": {"name": "blur_radius"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [374, 375, 376]}], "properties": {"proxyWidgets": [["-1", "expand"], ["-1", "blur_radius"]], "cnr_id": "comfy-core", "ver": "0.3.59"}, "widgets_values": [20, 31]}, {"id": 195, "type": "VAEEncode", "pos": [2950, 820], "size": [140, 46], "flags": {"collapsed": false}, "order": 11, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 371}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 317}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [358]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAEEncode"}, "widgets_values": []}, {"id": 181, "type": "ControlNetInpaintingAliMamaApply", "pos": [2940, 560], "size": [317.0093688964844, 206], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 299}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 300}, {"localized_name": 
"control_net", "name": "control_net", "type": "CONTROL_NET", "link": 301}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 384}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 385}, {"localized_name": "mask", "name": "mask", "type": "MASK", "link": 375}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}, {"localized_name": "start_percent", "name": "start_percent", "type": "FLOAT", "widget": {"name": "start_percent"}, "link": null}, {"localized_name": "end_percent", "name": "end_percent", "type": "FLOAT", "widget": {"name": "end_percent"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [386]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [387]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ControlNetInpaintingAliMamaApply"}, "widgets_values": [1, 0, 1]}, {"id": 178, "type": "VAELoader", "pos": [2430, 220], "size": [380, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 399}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [313, 317, 384]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAELoader", "models": [{"name": "qwen_image_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image_ComfyUI/resolve/main/split_files/vae/qwen_image_vae.safetensors", "directory": "vae"}]}, "widgets_values": ["qwen_image_vae.safetensors"]}, {"id": 182, "type": "CLIPTextEncode", "pos": [2850, 100], "size": [460, 164.31304931640625], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 305}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [299]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 176, "type": "CLIPTextEncode", "pos": [2850, 310], "size": [460, 140], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 296}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [300]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#223", "bgcolor": "#335"}, {"id": 191, "type": "VAEDecode", "pos": [3440, 580], "size": [250, 46], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 312}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 313}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [314, 323]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "VAEDecode"}, "widgets_values": []}, {"id": 219, "type": "2a4b2cc0-db37-4302-a067-da392f38f06b", "pos": [2480, 1260], 
"size": [280, 80], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 365}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 366}, {"name": "value", "type": "INT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [377]}, {"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [369, 370, 371, 385]}], "properties": {"proxyWidgets": [["-1", "value"]], "cnr_id": "comfy-core", "ver": "0.3.65"}, "widgets_values": [1536]}, {"id": 207, "type": "MaskPreview", "pos": [3430, 1270], "size": [340, 430], "flags": {}, "order": 15, "mode": 4, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 376}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "MaskPreview"}, "widgets_values": []}, {"id": 203, "type": "PreviewImage", "pos": [2990, 1270], "size": [310, 430], "flags": {}, "order": 14, "mode": 4, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 370}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "PreviewImage"}, "widgets_values": []}, {"id": 200, "type": "ImageCompositeMasked", "pos": [3850, 1280], "size": [250, 150], "flags": {}, "order": 12, "mode": 4, "inputs": [{"localized_name": "destination", "name": "destination", "type": "IMAGE", "link": 369}, {"localized_name": "source", "name": "source", "type": "IMAGE", "link": 323}, {"localized_name": "mask", "name": "mask", "shape": 7, "type": "MASK", "link": 374}, {"localized_name": "x", "name": "x", "type": "INT", "widget": {"name": "x"}, "link": null}, {"localized_name": "y", "name": "y", "type": "INT", "widget": {"name": "y"}, "link": null}, {"localized_name": "resize_source", "name": "resize_source", "type": "BOOLEAN", "widget": {"name": "resize_source"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageCompositeMasked"}, "widgets_values": [0, 0, false]}, {"id": 202, "type": "ImagePadForOutpaint", "pos": [2490, 1030], "size": [270, 174], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 351}, {"localized_name": "left", "name": "left", "type": "INT", "widget": {"name": "left"}, "link": 392}, {"localized_name": "top", "name": "top", "type": "INT", "widget": {"name": "top"}, "link": 393}, {"localized_name": "right", "name": "right", "type": "INT", "widget": {"name": "right"}, "link": 394}, {"localized_name": "bottom", "name": "bottom", "type": "INT", "widget": {"name": "bottom"}, "link": 395}, {"localized_name": "feathering", "name": "feathering", "type": "INT", "widget": {"name": "feathering"}, "link": 396}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [366]}, {"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [365]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImagePadForOutpaint"}, "widgets_values": [0, 0, 0, 0, 0]}], "groups": [{"id": 12, "title": "For outpainting Ctrl-B to enable", "bounding": [2410, -190, 1770, 1970], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 7, "title": "Step 1 - Upload models", "bounding": [2420, -150, 400, 610], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 9, "title": "Step 3 - Prompt", 
"bounding": [2840, 30, 490, 430], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 10, "title": "4 steps lightning LoRA", "bounding": [2840, -150, 490, 160], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 11, "title": "Ctrl-B to enable it", "bounding": [2420, 940, 430, 460], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 298, "origin_id": 185, "origin_slot": 0, "target_id": 180, "target_slot": 0, "type": "MODEL"}, {"id": 306, "origin_id": 175, "origin_slot": 0, "target_id": 185, "target_slot": 0, "type": "MODEL"}, {"id": 308, "origin_id": 180, "origin_slot": 0, "target_id": 190, "target_slot": 0, "type": "MODEL"}, {"id": 386, "origin_id": 181, "origin_slot": 0, "target_id": 190, "target_slot": 1, "type": "CONDITIONING"}, {"id": 387, "origin_id": 181, "origin_slot": 1, "target_id": 190, "target_slot": 2, "type": "CONDITIONING"}, {"id": 358, "origin_id": 195, "origin_slot": 0, "target_id": 190, "target_slot": 3, "type": "LATENT"}, {"id": 377, "origin_id": 219, "origin_slot": 0, "target_id": 220, "target_slot": 0, "type": "MASK"}, {"id": 371, "origin_id": 219, "origin_slot": 1, "target_id": 195, "target_slot": 0, "type": "IMAGE"}, {"id": 317, "origin_id": 178, "origin_slot": 0, "target_id": 195, "target_slot": 1, "type": "VAE"}, {"id": 299, "origin_id": 182, "origin_slot": 0, "target_id": 181, "target_slot": 0, "type": "CONDITIONING"}, {"id": 300, "origin_id": 176, "origin_slot": 0, "target_id": 181, "target_slot": 1, "type": "CONDITIONING"}, {"id": 301, "origin_id": 177, "origin_slot": 0, "target_id": 181, "target_slot": 2, "type": "CONTROL_NET"}, {"id": 384, "origin_id": 178, "origin_slot": 0, "target_id": 181, "target_slot": 3, "type": "VAE"}, {"id": 385, "origin_id": 219, "origin_slot": 1, "target_id": 181, "target_slot": 4, "type": "IMAGE"}, {"id": 375, "origin_id": 220, "origin_slot": 0, "target_id": 181, "target_slot": 5, "type": "MASK"}, {"id": 305, "origin_id": 174, "origin_slot": 0, "target_id": 182, "target_slot": 0, "type": "CLIP"}, {"id": 296, "origin_id": 174, "origin_slot": 0, "target_id": 176, "target_slot": 0, "type": "CLIP"}, {"id": 312, "origin_id": 190, "origin_slot": 0, "target_id": 191, "target_slot": 0, "type": "LATENT"}, {"id": 313, "origin_id": 178, "origin_slot": 0, "target_id": 191, "target_slot": 1, "type": "VAE"}, {"id": 365, "origin_id": 202, "origin_slot": 1, "target_id": 219, "target_slot": 0, "type": "MASK"}, {"id": 366, "origin_id": 202, "origin_slot": 0, "target_id": 219, "target_slot": 1, "type": "IMAGE"}, {"id": 376, "origin_id": 220, "origin_slot": 0, "target_id": 207, "target_slot": 0, "type": "MASK"}, {"id": 370, "origin_id": 219, "origin_slot": 1, "target_id": 203, "target_slot": 0, "type": "IMAGE"}, {"id": 369, "origin_id": 219, "origin_slot": 1, "target_id": 200, "target_slot": 0, "type": "IMAGE"}, {"id": 323, "origin_id": 191, "origin_slot": 0, "target_id": 200, "target_slot": 1, "type": "IMAGE"}, {"id": 374, "origin_id": 220, "origin_slot": 0, "target_id": 200, "target_slot": 2, "type": "MASK"}, {"id": 351, "origin_id": -10, "origin_slot": 0, "target_id": 202, "target_slot": 0, "type": "IMAGE"}, {"id": 314, "origin_id": 191, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 392, "origin_id": -10, "origin_slot": 1, "target_id": 202, "target_slot": 1, "type": "INT"}, {"id": 393, "origin_id": -10, "origin_slot": 2, "target_id": 202, "target_slot": 2, "type": "INT"}, {"id": 394, "origin_id": -10, "origin_slot": 3, "target_id": 202, "target_slot": 3, "type": "INT"}, {"id": 395, 
"origin_id": -10, "origin_slot": 4, "target_id": 202, "target_slot": 4, "type": "INT"}, {"id": 396, "origin_id": -10, "origin_slot": 5, "target_id": 202, "target_slot": 5, "type": "INT"}, {"id": 397, "origin_id": -10, "origin_slot": 6, "target_id": 175, "target_slot": 0, "type": "COMBO"}, {"id": 398, "origin_id": -10, "origin_slot": 7, "target_id": 174, "target_slot": 0, "type": "COMBO"}, {"id": 399, "origin_id": -10, "origin_slot": 8, "target_id": 178, "target_slot": 0, "type": "COMBO"}, {"id": 400, "origin_id": -10, "origin_slot": 9, "target_id": 177, "target_slot": 0, "type": "COMBO"}, {"id": 401, "origin_id": -10, "origin_slot": 10, "target_id": 185, "target_slot": 1, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Outpaint image"}, {"id": "f93c215e-c393-460e-9534-ed2c3d8a652e", "version": 1, "state": {"lastGroupId": 14, "lastNodeId": 255, "lastLinkId": 401, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Grow and Blur Mask", "inputNode": {"id": -10, "bounding": [290, 3536, 120, 100]}, "outputNode": {"id": -20, "bounding": [1130, 3536, 120, 60]}, "inputs": [{"id": "3ac60d5e-8f9d-4663-9b24-b3a15a3e9e20", "name": "mask", "type": "MASK", "linkIds": [279], "localized_name": "mask", "pos": [390, 3556]}, {"id": "d1ab0cf5-7062-41ac-9f4b-8c660fc4a714", "name": "expand", "type": "INT", "linkIds": [379], "pos": [390, 3576]}, {"id": "1a787af5-da9f-44c5-9f5a-3f71609ca0ef", "name": "blur_radius", "type": "INT", "linkIds": [380], "pos": [390, 3596]}], "outputs": [{"id": "1f97f683-13d3-4871-876d-678fca850d89", "name": "MASK", "type": "MASK", "linkIds": [378], "localized_name": "MASK", "pos": [1150, 3556]}], "widgets": [], "nodes": [{"id": 253, "type": "ImageToMask", "pos": [800, 3630], "size": [270, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 377}, {"localized_name": "channel", "name": "channel", "type": "COMBO", "widget": {"name": "channel"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [378]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageToMask"}, "widgets_values": ["red"]}, {"id": 251, "type": "MaskToImage", "pos": [780, 3470], "size": [260, 70], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 372}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [373]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "MaskToImage"}, "widgets_values": []}, {"id": 199, "type": "GrowMask", "pos": [470, 3460], "size": [270, 82], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 279}, {"localized_name": "expand", "name": "expand", "type": "INT", "widget": {"name": "expand"}, "link": 379}, {"localized_name": "tapered_corners", "name": "tapered_corners", "type": "BOOLEAN", "widget": {"name": "tapered_corners"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [372]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "GrowMask"}, "widgets_values": [20, true]}, {"id": 252, "type": "ImageBlur", "pos": [480, 3620], "size": [270, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 373}, {"localized_name": "blur_radius", "name": "blur_radius", "type": "INT", 
"widget": {"name": "blur_radius"}, "link": 380}, {"localized_name": "sigma", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [377]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageBlur"}, "widgets_values": [31, 1]}], "groups": [], "links": [{"id": 373, "origin_id": 251, "origin_slot": 0, "target_id": 252, "target_slot": 0, "type": "IMAGE"}, {"id": 377, "origin_id": 252, "origin_slot": 0, "target_id": 253, "target_slot": 0, "type": "IMAGE"}, {"id": 372, "origin_id": 199, "origin_slot": 0, "target_id": 251, "target_slot": 0, "type": "MASK"}, {"id": 279, "origin_id": -10, "origin_slot": 0, "target_id": 199, "target_slot": 0, "type": "MASK"}, {"id": 378, "origin_id": 253, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "MASK"}, {"id": 379, "origin_id": -10, "origin_slot": 1, "target_id": 199, "target_slot": 1, "type": "INT"}, {"id": 380, "origin_id": -10, "origin_slot": 2, "target_id": 252, "target_slot": 1, "type": "INT"}], "extra": {"workflowRendererVersion": "LG"}}, {"id": "2a4b2cc0-db37-4302-a067-da392f38f06b", "version": 1, "state": {"lastGroupId": 14, "lastNodeId": 255, "lastLinkId": 401, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Scale image and mask", "inputNode": {"id": -10, "bounding": [2110, 1406, 120, 100]}, "outputNode": {"id": -20, "bounding": [3320, 1406, 120, 80]}, "inputs": [{"id": "53ec80db-b075-446c-a79b-891d82ae3cf1", "name": "mask", "type": "MASK", "linkIds": [360], "localized_name": "mask", "pos": [2210, 1426]}, {"id": "37820e3d-f495-4b41-b0c6-58765a0c1766", "name": "image", "type": "IMAGE", "linkIds": [350], "localized_name": "image", "pos": [2210, 1446]}, {"id": "d388f5f1-7a36-4563-b104-9f7ec77f636d", "name": "value", "type": "INT", "linkIds": [365], "pos": [2210, 1466]}], "outputs": [{"id": "7ef75a31-2e69-4dce-8e13-76cd17b4c272", "name": "MASK", "type": "MASK", "linkIds": [364], "localized_name": "MASK", "pos": [3340, 1426]}, {"id": "36058145-b72c-4bd4-bb63-e2e22456d003", "name": "IMAGE", "type": "IMAGE", "linkIds": [352, 353, 354], "localized_name": "IMAGE", "pos": [3340, 1446]}], "widgets": [], "nodes": [{"id": 218, "type": "ImageToMask", "pos": [2990, 1540], "size": [270, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 363}, {"localized_name": "channel", "name": "channel", "type": "COMBO", "widget": {"name": "channel"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [364]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.65", "Node name for S&R": "ImageToMask"}, "widgets_values": ["red"]}, {"id": 216, "type": "ImageScaleToMaxDimension", "pos": [2610, 1570], "size": [281.2027282714844, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 361}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "largest_size", "name": "largest_size", "type": "INT", "widget": {"name": "largest_size"}, "link": 362}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [363]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageScaleToMaxDimension"}, "widgets_values": ["area", 1536]}, {"id": 217, "type": "MaskToImage", "pos": [2700, 1420], "size": 
[193.2779296875, 26], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 360}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [361]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.65", "Node name for S&R": "MaskToImage"}, "widgets_values": []}, {"id": 194, "type": "ImageScaleToMaxDimension", "pos": [2590, 1280], "size": [281.2027282714844, 82], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 350}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "largest_size", "name": "largest_size", "type": "INT", "widget": {"name": "largest_size"}, "link": 359}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [352, 353, 354]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageScaleToMaxDimension"}, "widgets_values": ["area", 1536]}, {"id": 215, "type": "PrimitiveInt", "pos": [2260, 1560], "size": [270, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "INT", "widget": {"name": "value"}, "link": 365}], "outputs": [{"localized_name": "INT", "name": "INT", "type": "INT", "links": [359, 362]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.65", "Node name for S&R": "PrimitiveInt"}, "widgets_values": [1536, "fixed"]}], "groups": [], "links": [{"id": 363, "origin_id": 216, "origin_slot": 0, "target_id": 218, "target_slot": 0, "type": "IMAGE"}, {"id": 361, "origin_id": 217, "origin_slot": 0, "target_id": 216, "target_slot": 0, "type": "IMAGE"}, {"id": 362, "origin_id": 215, "origin_slot": 0, "target_id": 216, "target_slot": 2, "type": "INT"}, {"id": 359, "origin_id": 215, "origin_slot": 0, "target_id": 194, "target_slot": 2, "type": "INT"}, {"id": 360, "origin_id": -10, "origin_slot": 0, "target_id": 217, "target_slot": 0, "type": "MASK"}, {"id": 350, "origin_id": -10, "origin_slot": 1, "target_id": 194, "target_slot": 0, "type": "IMAGE"}, {"id": 364, "origin_id": 218, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "MASK"}, {"id": 352, "origin_id": 194, "origin_slot": 0, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 353, "origin_id": 194, "origin_slot": 0, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 354, "origin_id": 194, "origin_slot": 0, "target_id": -20, "target_slot": 1, "type": "IMAGE"}, {"id": 365, "origin_id": -10, "origin_slot": 2, "target_id": 215, "target_slot": 0, "type": "INT"}], "extra": {"workflowRendererVersion": "LG"}}]}, "config": {}, "extra": {"workflowRendererVersion": "LG", "ds": {"scale": 1.170393777345649, "offset": [-2589.3260157061272, -547.3616692627206]}}, "version": 0.4}
diff --git a/blueprints/Image Upscale(Z-image-Turbo).json b/blueprints/Image Upscale(Z-image-Turbo).json
new file mode 100644
index 000000000..a67d6a2d8
--- /dev/null
+++ b/blueprints/Image Upscale(Z-image-Turbo).json
@@ -0,0 +1 @@
+{"id": "bf8108f3-d857-46c9-aef5-0e8ad2a64bf5", "revision": 0, "last_node_id": 95, "last_link_id": 115, "nodes": [{"id": 87, "type": "dd15cfd3-cd53-428c-b3e2-33ed4ff8fa78", "pos": [960.6668984200231, 332.66676187423354], "size": [400, 469.9869791666667], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name":
"unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"label": "upscale_model", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}, {"name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE_1", "name": "IMAGE_1", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["67", "text"], ["69", "seed"], ["69", "control_after_generate"], ["-1", "denoise"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [null, null, null, 0.33, "z_image_turbo_bf16.safetensors", "qwen_3_4b.safetensors", "ae.safetensors", "RealESRGAN_x4plus.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "dd15cfd3-cd53-428c-b3e2-33ed4ff8fa78", "version": 1, "state": {"lastGroupId": 5, "lastNodeId": 95, "lastLinkId": 115, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image Upscale(Z-image-Turbo)", "inputNode": {"id": -10, "bounding": [-150, 390, 125.224609375, 160]}, "outputNode": {"id": -20, "bounding": [2070, 490, 120, 60]}, "inputs": [{"id": "e9a14390-4f93-4065-8b02-323f999527c0", "name": "image", "type": "IMAGE", "linkIds": [86], "localized_name": "image", "pos": [-44.775390625, 410]}, {"id": "c5655e11-9531-4949-996c-958b5fe92085", "name": "unet_name", "type": "COMBO", "linkIds": [109], "pos": [-44.775390625, 430]}, {"id": "82576043-dd69-4604-b572-09fabb6e602d", "name": "clip_name", "type": "COMBO", "linkIds": [110], "pos": [-44.775390625, 450]}, {"id": "59e20fb5-cd61-4d4b-a1fd-15a90c7ba6c2", "name": "vae_name", "type": "COMBO", "linkIds": [111], "pos": [-44.775390625, 470]}, {"id": "adc35153-dc52-4bac-be7e-9da19471f441", "name": "model_name", "type": "COMBO", "linkIds": [112], "label": "upscale_model", "pos": [-44.775390625, 490]}, {"id": "c1b2f097-616e-4420-93c8-04eb79f4ba1e", "name": "denoise", "type": "FLOAT", "linkIds": [115], "pos": [-44.775390625, 510]}], "outputs": [{"id": "f138a0aa-489a-42e1-92f7-e3747688c94d", "name": "IMAGE_1", "type": "IMAGE", "linkIds": [97, 103], "localized_name": "IMAGE_1", "label": "IMAGE", "pos": [2090, 510]}], "widgets": [], "nodes": [{"id": 71, "type": "CLIPTextEncode", "pos": [648.333324162179, 398.3333435177784], "size": [491.6666666666667, 150], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 82}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [83]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#323", "bgcolor": "#535"}, {"id": 79, "type": "ImageUpscaleWithModel", "pos": [623.3333541162552, 714.9999406294688], "size": [233.5689453125, 60], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 87}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 88}], "outputs": 
[{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [92]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "ImageUpscaleWithModel"}, "widgets_values": []}, {"id": 80, "type": "VAEEncode", "pos": [1173.3330331592938, 631.6665944654844], "size": [187.5, 60], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 93}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 90}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [91]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "VAEEncode"}, "widgets_values": []}, {"id": 81, "type": "ImageScaleBy", "pos": [865.0000410901742, 714.9999828835583], "size": [225, 95.546875], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 92}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "scale_by", "name": "scale_by", "type": "FLOAT", "widget": {"name": "scale_by"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [93]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "ImageScaleBy"}, "widgets_values": ["lanczos", 0.5]}, {"id": 66, "type": "UNETLoader", "pos": [280, -20], "size": [323.984375, 118.64583333333334], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 109}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [104]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "UNETLoader", "models": [{"name": "z_image_turbo_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["z_image_turbo_bf16.safetensors", "default"]}, {"id": 62, "type": "CLIPLoader", "pos": [280, 140], "size": [323.984375, 150.65104166666669], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 110}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [78, 82]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_3_4b.safetensors", "lumina2", "default"]}, {"id": 67, "type": "CLIPTextEncode", "pos": 
[650.621298596813, -33.81729273975067], "size": [491.9791666666667, 377.98177083333337], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 78}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [75]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["masterpiece, 8k"], "color": "#232", "bgcolor": "#353"}, {"id": 63, "type": "VAELoader", "pos": [280, 330], "size": [323.984375, 83.99739583333334], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 111}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [73, 90]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "VAELoader", "models": [{"name": "ae.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ae.safetensors"]}, {"id": 76, "type": "UpscaleModelLoader", "pos": [264.07395879037364, 704.8118881098496], "size": [323.984375, 83.99739583333334], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 112}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [87]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}, {"id": 70, "type": "ModelSamplingAuraFlow", "pos": [1200, -50], "size": [371.9791666666667, 80.1171875], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 104}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [74]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 65, "type": "VAEDecode", "pos": [1610, -50], "size": [251.97916666666669, 72.13541666666667], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 72}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 73}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [97, 103]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for 
S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 78, "type": "ImageScaleToTotalPixels", "pos": [260, 850], "size": [325, 122.21354166666667], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 86}, {"localized_name": "upscale_method", "name": "upscale_method", "type": "COMBO", "widget": {"name": "upscale_method"}, "link": null}, {"localized_name": "megapixels", "name": "megapixels", "type": "FLOAT", "widget": {"name": "megapixels"}, "link": null}, {"localized_name": "resolution_steps", "name": "resolution_steps", "type": "INT", "widget": {"name": "resolution_steps"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [88]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "ImageScaleToTotalPixels"}, "widgets_values": ["lanczos", 1, 1]}, {"id": 69, "type": "KSampler", "pos": [1200, 80], "size": [366.6666666666667, 474], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 74}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 75}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 83}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 91}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": 115}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [72]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1098688918602660, "randomize", 5, 1, "dpmpp_2m_sde", "beta", 0.33]}], "groups": [{"id": 3, "title": "Prompt", "bounding": [640, -90, 508.64583333333337, 662.0666813520016], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 5, "title": "Models", "bounding": [260, -90, 344.6965254233087, 516.414685926878], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 104, "origin_id": 66, "origin_slot": 0, "target_id": 70, "target_slot": 0, "type": "MODEL"}, {"id": 82, "origin_id": 62, "origin_slot": 0, "target_id": 71, "target_slot": 0, "type": "CLIP"}, {"id": 87, "origin_id": 76, "origin_slot": 0, "target_id": 79, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 88, "origin_id": 78, "origin_slot": 0, "target_id": 79, "target_slot": 1, "type": "IMAGE"}, {"id": 93, "origin_id": 81, "origin_slot": 0, "target_id": 80, "target_slot": 0, "type": "IMAGE"}, {"id": 90, "origin_id": 63, "origin_slot": 0, "target_id": 80, "target_slot": 1, "type": "VAE"}, {"id": 92, "origin_id": 79, 
"origin_slot": 0, "target_id": 81, "target_slot": 0, "type": "IMAGE"}, {"id": 74, "origin_id": 70, "origin_slot": 0, "target_id": 69, "target_slot": 0, "type": "MODEL"}, {"id": 75, "origin_id": 67, "origin_slot": 0, "target_id": 69, "target_slot": 1, "type": "CONDITIONING"}, {"id": 83, "origin_id": 71, "origin_slot": 0, "target_id": 69, "target_slot": 2, "type": "CONDITIONING"}, {"id": 91, "origin_id": 80, "origin_slot": 0, "target_id": 69, "target_slot": 3, "type": "LATENT"}, {"id": 72, "origin_id": 69, "origin_slot": 0, "target_id": 65, "target_slot": 0, "type": "LATENT"}, {"id": 73, "origin_id": 63, "origin_slot": 0, "target_id": 65, "target_slot": 1, "type": "VAE"}, {"id": 78, "origin_id": 62, "origin_slot": 0, "target_id": 67, "target_slot": 0, "type": "CLIP"}, {"id": 86, "origin_id": -10, "origin_slot": 0, "target_id": 78, "target_slot": 0, "type": "IMAGE"}, {"id": 97, "origin_id": 65, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 103, "origin_id": 65, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 109, "origin_id": -10, "origin_slot": 1, "target_id": 66, "target_slot": 0, "type": "COMBO"}, {"id": 110, "origin_id": -10, "origin_slot": 2, "target_id": 62, "target_slot": 0, "type": "COMBO"}, {"id": 111, "origin_id": -10, "origin_slot": 3, "target_id": 63, "target_slot": 0, "type": "COMBO"}, {"id": 112, "origin_id": -10, "origin_slot": 4, "target_id": 76, "target_slot": 0, "type": "COMBO"}, {"id": 115, "origin_id": -10, "origin_slot": 5, "target_id": 69, "target_slot": 9, "type": "FLOAT"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Enhance"}]}, "config": {}, "extra": {"workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Image to Depth Map (Lotus).json b/blueprints/Image to Depth Map (Lotus).json new file mode 100644 index 000000000..5b3f7a1d6 --- /dev/null +++ b/blueprints/Image to Depth Map (Lotus).json @@ -0,0 +1 @@ +{"id": "6af0a6c1-0161-4528-8685-65776e838d44", "revision": 0, "last_node_id": 75, "last_link_id": 245, "nodes": [{"id": 75, "type": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf", "pos": [600, 830], "size": [400, 110], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": null}, {"label": "depth_intensity", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["-1", "sigma"], ["-1", "unet_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [999.0000000000002, "lotus-depth-d-v1-1.safetensors", "vae-ft-mse-840000-ema-pruned.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "488652fd-6edf-4d06-8f9f-4d84d3a34eaf", "version": 1, "state": {"lastGroupId": 1, "lastNodeId": 75, "lastLinkId": 245, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image to Depth Map (Lotus)", "inputNode": {"id": -10, "bounding": [-60, -172.61268043518066, 126.625, 120]}, "outputNode": {"id": -20, "bounding": [1650, -172.61268043518066, 120, 60]}, "inputs": [{"id": "3bdd30c3-4ec9-485a-814b-e7d39fb6b5cc", "name": "pixels", "type": "IMAGE", "linkIds": [37], "localized_name": "pixels", "pos": [46.625, -152.61268043518066]}, 
{"id": "f9a1017c-f4b9-43b4-94c2-41c088b3a492", "name": "sigma", "type": "FLOAT", "linkIds": [243], "label": "depth_intensity", "pos": [46.625, -132.61268043518066]}, {"id": "cb96b9fe-93e7-41cf-b27f-6d6dc3a1890b", "name": "unet_name", "type": "COMBO", "linkIds": [244], "pos": [46.625, -112.61268043518066]}, {"id": "42c8efad-1661-49c7-89b5-2b735b72424d", "name": "vae_name", "type": "COMBO", "linkIds": [245], "pos": [46.625, -92.61268043518066]}], "outputs": [{"id": "2ec278bd-0b66-4b30-9c5b-994d5f638214", "name": "IMAGE", "type": "IMAGE", "linkIds": [242], "localized_name": "IMAGE", "pos": [1670, -152.61268043518066]}], "widgets": [], "nodes": [{"id": 10, "type": "UNETLoader", "pos": [108.05555555555557, -253.05555555555557], "size": [254.93706597222226, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 244}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [31, 241]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "UNETLoader", "models": [{"name": "lotus-depth-d-v1-1.safetensors", "url": "https://huggingface.co/Comfy-Org/lotus/resolve/main/lotus-depth-d-v1-1.safetensors", "directory": "diffusion_models"}], "widget_ue_connectable": {}}, "widgets_values": ["lotus-depth-d-v1-1.safetensors", "default"]}, {"id": 18, "type": "DisableNoise", "pos": [607.0641494069639, -268.33337840371513], "size": [175, 33.333333333333336], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "slot_index": 0, "links": [237]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "DisableNoise", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 23, "type": "VAEEncode", "pos": [620, 160], "size": [175, 50], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 37}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 38}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [201]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEEncode", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 21, "type": "KSamplerSelect", "pos": [610, -60], "size": [210, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "slot_index": 0, "links": [33]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "KSamplerSelect", "widget_ue_connectable": {}}, "widgets_values": ["euler"]}, {"id": 19, "type": "BasicGuider", "pos": [610, -170], "size": [175, 50], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 241}, {"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 238}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "slot_index": 0, "links": [27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicGuider", "widget_ue_connectable": {}}, "widgets_values": []}, 
{"id": 16, "type": "SamplerCustomAdvanced", "pos": [890, -130], "size": [295.99609375, 271.65798611111114], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 237}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 27}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 33}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 194}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 201}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "slot_index": 0, "links": [232]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "slot_index": 1, "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SamplerCustomAdvanced", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 28, "type": "SetFirstSigma", "pos": [620, 50], "size": [210, 58], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 66}, {"localized_name": "sigma", "name": "sigma", "type": "FLOAT", "widget": {"name": "sigma"}, "link": 243}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [194]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "SetFirstSigma", "widget_ue_connectable": {}}, "widgets_values": [999.0000000000002]}, {"id": 8, "type": "VAEDecode", "pos": [1210, -120], "size": [175, 50], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 232}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 240}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [35]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEDecode", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 22, "type": "ImageInvert", "pos": [1200, -220], "size": [175, 33.333333333333336], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 35}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [242]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "ImageInvert", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 14, "type": "VAELoader", "pos": [120, -90], "size": [254.93706597222226, 58], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 245}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [38, 240]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAELoader", "models": [{"name": "vae-ft-mse-840000-ema-pruned.safetensors", "url": "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.safetensors", "directory": "vae"}], "widget_ue_connectable": {}}, "widgets_values": ["vae-ft-mse-840000-ema-pruned.safetensors"]}, {"id": 68, "type": "LotusConditioning", "pos": [400, -150], "size": [175, 33.333333333333336], "flags": {}, "order": 2, "mode": 0, "inputs": [], "outputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "slot_index": 0, "links": [238]}], "properties": 
{"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "LotusConditioning", "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 20, "type": "BasicScheduler", "pos": [170, 40], "size": [210, 106], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 31}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "slot_index": 0, "links": [66]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "BasicScheduler", "widget_ue_connectable": {}}, "widgets_values": ["normal", 1, 1]}], "groups": [], "links": [{"id": 232, "origin_id": 16, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 240, "origin_id": 14, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 237, "origin_id": 18, "origin_slot": 0, "target_id": 16, "target_slot": 0, "type": "NOISE"}, {"id": 27, "origin_id": 19, "origin_slot": 0, "target_id": 16, "target_slot": 1, "type": "GUIDER"}, {"id": 33, "origin_id": 21, "origin_slot": 0, "target_id": 16, "target_slot": 2, "type": "SAMPLER"}, {"id": 194, "origin_id": 28, "origin_slot": 0, "target_id": 16, "target_slot": 3, "type": "SIGMAS"}, {"id": 201, "origin_id": 23, "origin_slot": 0, "target_id": 16, "target_slot": 4, "type": "LATENT"}, {"id": 241, "origin_id": 10, "origin_slot": 0, "target_id": 19, "target_slot": 0, "type": "MODEL"}, {"id": 238, "origin_id": 68, "origin_slot": 0, "target_id": 19, "target_slot": 1, "type": "CONDITIONING"}, {"id": 31, "origin_id": 10, "origin_slot": 0, "target_id": 20, "target_slot": 0, "type": "MODEL"}, {"id": 35, "origin_id": 8, "origin_slot": 0, "target_id": 22, "target_slot": 0, "type": "IMAGE"}, {"id": 38, "origin_id": 14, "origin_slot": 0, "target_id": 23, "target_slot": 1, "type": "VAE"}, {"id": 66, "origin_id": 20, "origin_slot": 0, "target_id": 28, "target_slot": 0, "type": "SIGMAS"}, {"id": 37, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 242, "origin_id": 22, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 243, "origin_id": -10, "origin_slot": 1, "target_id": 28, "target_slot": 1, "type": "FLOAT"}, {"id": 244, "origin_id": -10, "origin_slot": 2, "target_id": 10, "target_slot": 0, "type": "COMBO"}, {"id": 245, "origin_id": -10, "origin_slot": 3, "target_id": 14, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Depth to image"}]}, "config": {}, "extra": {"ds": {"scale": 1.3589709866044692, "offset": [-138.53613935617864, -786.0629126022195]}, "workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Image to Layers(Qwen-Image Layered).json b/blueprints/Image to Layers(Qwen-Image Layered).json new file mode 100644 index 000000000..f4c7f0b5f --- /dev/null +++ b/blueprints/Image to Layers(Qwen-Image Layered).json @@ -0,0 +1 @@ +{"id": "1a761372-7c82-4016-b9bf-fa285967e1e9", "revision": 0, "last_node_id": 83, "last_link_id": 0, "nodes": [{"id": 83, "type": "f754a936-daaf-4b6e-9658-41fdc54d301d", "pos": [61.999827823554256, 153.3332507624185], "size": [400, 550], "flags": {}, 
"order": 0, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": null}, {"name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"name": "layers", "type": "INT", "widget": {"name": "layers"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "steps"], ["-1", "cfg"], ["-1", "layers"], ["3", "seed"], ["3", "control_after_generate"]], "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", 20, 2.5, 2]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "f754a936-daaf-4b6e-9658-41fdc54d301d", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 83, "lastLinkId": 159, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Image to Layers (Qwen-Image-Layered)", "inputNode": {"id": -10, "bounding": [-510, 523, 120, 140]}, "outputNode": {"id": -20, "bounding": [1160, 523, 120, 60]}, "inputs": [{"id": "6c36b5bc-c9a5-4b07-8b52-6fe0df434cce", "name": "image", "type": "IMAGE", "linkIds": [148, 149], "localized_name": "image", "pos": [-410, 543]}, {"id": "8497fe33-124d-4e3e-9ab6-fc4a56a98dde", "name": "text", "type": "STRING", "linkIds": [150], "pos": [-410, 563]}, {"id": "509ab2c1-e6da-47ba-8714-023100ab92bd", "name": "steps", "type": "INT", "linkIds": [153], "pos": [-410, 583]}, {"id": "dd81894e-5def-4c75-9b17-d8f89fe095d6", "name": "cfg", "type": "FLOAT", "linkIds": [154], "pos": [-410, 603]}, {"id": "66da7c8a-3369-4a3f-92f2-3073afc55e7d", "name": "layers", "type": "INT", "linkIds": [159], "pos": [-410, 623]}], "outputs": [{"id": "7df75921-6729-4aad-bfc1-fcc536c2d298", "name": "IMAGE", "type": "IMAGE", "linkIds": [110], "localized_name": "IMAGE", "pos": [1180, 543]}], "widgets": [], "nodes": [{"id": 38, "type": "CLIPLoader", "pos": [-320, 310], "size": [346.7470703125, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": [74, 75]}], "properties": {"Node name for S&R": "CLIPLoader", "cnr_id": "comfy-core", "ver": "0.5.1", "models": [{"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/HunyuanVideo_1.5_repackaged/resolve/main/split_files/text_encoders/qwen_2.5_vl_7b_fp8_scaled.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_2.5_vl_7b_fp8_scaled.safetensors", "qwen_image", "default"]}, {"id": 39, "type": "VAELoader", "pos": [-320, 460], "size": [346.7470703125, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "VAE", 
"name": "VAE", "type": "VAE", "slot_index": 0, "links": [76, 139]}], "properties": {"Node name for S&R": "VAELoader", "cnr_id": "comfy-core", "ver": "0.5.1", "models": [{"name": "qwen_image_layered_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image-Layered_ComfyUI/resolve/main/split_files/vae/qwen_image_layered_vae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_image_layered_vae.safetensors"]}, {"id": 7, "type": "CLIPTextEncode", "pos": [70, 420], "size": [425.27801513671875, 180.6060791015625], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 75}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [131]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"Node name for S&R": "CLIPTextEncode", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#322", "bgcolor": "#533"}, {"id": 70, "type": "ReferenceLatent", "pos": [330, 670], "size": [204.1666717529297, 46], "flags": {"collapsed": true}, "order": 9, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 131}, {"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 134}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [132]}], "properties": {"Node name for S&R": "ReferenceLatent", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 69, "type": "ReferenceLatent", "pos": [330, 710], "size": [204.1666717529297, 46], "flags": {"collapsed": true}, "order": 8, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 129}, {"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 133}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [130]}], "properties": {"Node name for S&R": "ReferenceLatent", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 66, "type": "ModelSamplingAuraFlow", "pos": [530, 150], "size": [270, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 126}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [125]}], "properties": {"Node name for S&R": "ModelSamplingAuraFlow", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 76, 
"type": "LatentCutToBatch", "pos": [830, 160], "size": [270, 82], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 142}, {"localized_name": "dim", "name": "dim", "type": "COMBO", "widget": {"name": "dim"}, "link": null}, {"localized_name": "slice_size", "name": "slice_size", "type": "INT", "widget": {"name": "slice_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [143]}], "properties": {"Node name for S&R": "LatentCutToBatch", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["t", 1]}, {"id": 71, "type": "VAEEncode", "pos": [100, 690], "size": [140, 46], "flags": {"collapsed": false}, "order": 10, "mode": 0, "inputs": [{"localized_name": "pixels", "name": "pixels", "type": "IMAGE", "link": 149}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 139}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [133, 134]}], "properties": {"Node name for S&R": "VAEEncode", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 8, "type": "VAEDecode", "pos": [850, 310], "size": [210, 46], "flags": {"collapsed": true}, "order": 7, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 143}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 76}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [110]}], "properties": {"Node name for S&R": "VAEDecode", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 6, "type": "CLIPTextEncode", "pos": [70, 180], "size": [422.84503173828125, 164.31304931640625], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 74}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 150}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [129]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"Node name for S&R": "CLIPTextEncode", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 3, "type": "KSampler", "pos": [530, 280], "size": [270, 400], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 125}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 130}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 132}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 157}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": 
"steps"}, "link": 153}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": 154}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [142]}], "properties": {"Node name for S&R": "KSampler", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize", 20, 2.5, "euler", "simple", 1]}, {"id": 78, "type": "GetImageSize", "pos": [80, 790], "size": [210, 136], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 148}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [155]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [156]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": null}], "properties": {"Node name for S&R": "GetImageSize", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 83, "type": "EmptyQwenImageLayeredLatentImage", "pos": [320, 790], "size": [330.9341796875, 130], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 155}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 156}, {"localized_name": "layers", "name": "layers", "type": "INT", "widget": {"name": "layers"}, "link": 159}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [157]}], "properties": {"Node name for S&R": "EmptyQwenImageLayeredLatentImage", "cnr_id": "comfy-core", "ver": "0.5.1", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [640, 640, 2, 1]}, {"id": 37, "type": "UNETLoader", "pos": [-320, 180], "size": [346.7470703125, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [126]}], "properties": {"Node name for S&R": "UNETLoader", "cnr_id": "comfy-core", "ver": "0.5.1", "models": [{"name": "qwen_image_layered_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/Qwen-Image-Layered_ComfyUI/resolve/main/split_files/diffusion_models/qwen_image_layered_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 
80, "secondTabWidth": 65}, "widgets_values": ["qwen_image_layered_bf16.safetensors", "default"]}], "groups": [{"id": 1, "title": "Prompt(Optional)", "bounding": [60, 110, 450, 510], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Load Models", "bounding": [-330, 110, 366.7470703125, 421.6], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 75, "origin_id": 38, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "CLIP"}, {"id": 131, "origin_id": 7, "origin_slot": 0, "target_id": 70, "target_slot": 0, "type": "CONDITIONING"}, {"id": 134, "origin_id": 71, "origin_slot": 0, "target_id": 70, "target_slot": 1, "type": "LATENT"}, {"id": 129, "origin_id": 6, "origin_slot": 0, "target_id": 69, "target_slot": 0, "type": "CONDITIONING"}, {"id": 133, "origin_id": 71, "origin_slot": 0, "target_id": 69, "target_slot": 1, "type": "LATENT"}, {"id": 126, "origin_id": 37, "origin_slot": 0, "target_id": 66, "target_slot": 0, "type": "MODEL"}, {"id": 125, "origin_id": 66, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 130, "origin_id": 69, "origin_slot": 0, "target_id": 3, "target_slot": 1, "type": "CONDITIONING"}, {"id": 132, "origin_id": 70, "origin_slot": 0, "target_id": 3, "target_slot": 2, "type": "CONDITIONING"}, {"id": 142, "origin_id": 3, "origin_slot": 0, "target_id": 76, "target_slot": 0, "type": "LATENT"}, {"id": 74, "origin_id": 38, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "CLIP"}, {"id": 139, "origin_id": 39, "origin_slot": 0, "target_id": 71, "target_slot": 1, "type": "VAE"}, {"id": 143, "origin_id": 76, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 76, "origin_id": 39, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 148, "origin_id": -10, "origin_slot": 0, "target_id": 78, "target_slot": 0, "type": "IMAGE"}, {"id": 149, "origin_id": -10, "origin_slot": 0, "target_id": 71, "target_slot": 0, "type": "IMAGE"}, {"id": 110, "origin_id": 8, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 150, "origin_id": -10, "origin_slot": 1, "target_id": 6, "target_slot": 1, "type": "STRING"}, {"id": 153, "origin_id": -10, "origin_slot": 2, "target_id": 3, "target_slot": 5, "type": "INT"}, {"id": 154, "origin_id": -10, "origin_slot": 3, "target_id": 3, "target_slot": 6, "type": "FLOAT"}, {"id": 155, "origin_id": 78, "origin_slot": 0, "target_id": 83, "target_slot": 0, "type": "INT"}, {"id": 156, "origin_id": 78, "origin_slot": 1, "target_id": 83, "target_slot": 1, "type": "INT"}, {"id": 157, "origin_id": 83, "origin_slot": 0, "target_id": 3, "target_slot": 3, "type": "LATENT"}, {"id": 159, "origin_id": -10, "origin_slot": 4, "target_id": 83, "target_slot": 2, "type": "INT"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Image to layers"}]}, "config": {}, "extra": {"ds": {"scale": 1.14, "offset": [695.5933739308316, 6.855893974423647]}, "workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Image to Model (Hunyuan3d 2.1).json b/blueprints/Image to Model (Hunyuan3d 2.1).json new file mode 100644 index 000000000..04b2d9bc9 --- /dev/null +++ b/blueprints/Image to Model (Hunyuan3d 2.1).json @@ -0,0 +1 @@ +{"id": "8fe311ec-2147-47a8-b618-7bd6fb6d4f9d", "revision": 0, "last_node_id": 23, "last_link_id": 24, "nodes": [{"id": 19, "type": "feb7d184-edf3-4851-9fd6-57a92c00ec42", "pos": [277.7327250391088, 256.4066470374603], "size": [340, 70], "flags": {}, "order": 0, "mode": 0, 
"inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": null}, {"name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": null}], "outputs": [{"localized_name": "MESH", "name": "MESH", "type": "MESH", "links": []}], "properties": {"proxyWidgets": [["-1", "ckpt_name"]], "cnr_id": "comfy-core", "ver": "0.3.65"}, "widgets_values": ["hunyuan_3d_v2.1.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "feb7d184-edf3-4851-9fd6-57a92c00ec42", "version": 1, "state": {"lastGroupId": 2, "lastNodeId": 23, "lastLinkId": 24, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image to Model (Hunyuan3d 2.1)", "inputNode": {"id": -10, "bounding": [-138.94803619384766, -392.62060546875, 120, 80]}, "outputNode": {"id": -20, "bounding": [1090, -310, 120, 60]}, "inputs": [{"id": "ab9b5b83-88f9-4698-954d-93f644bd07aa", "name": "image", "type": "IMAGE", "linkIds": [21], "localized_name": "image", "pos": [-38.948036193847656, -372.62060546875]}, {"id": "e15b0ba4-b5fe-41eb-9266-006ce1f1cf79", "name": "ckpt_name", "type": "COMBO", "linkIds": [23], "pos": [-38.948036193847656, -352.62060546875]}], "outputs": [{"id": "c8744662-e812-49b3-8bc8-744d557db6d6", "name": "MESH", "type": "MESH", "linkIds": [11], "localized_name": "MESH", "pos": [1110, -290]}], "widgets": [], "nodes": [{"id": 7, "type": "KSampler", "pos": [760, -510], "size": [270, 262], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 19}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 5}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 6}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 7}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [8]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "KSampler"}, "widgets_values": [894796671366012, "randomize", 30, 5, "euler", "normal", 1]}, {"id": 13, "type": "CLIPVisionEncode", "pos": [450, -410], "size": [270, 80], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "clip_vision", "name": "clip_vision", "type": "CLIP_VISION", "link": 20}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 21}, {"localized_name": "crop", "name": "crop", "type": "COMBO", "widget": {"name": "crop"}, "link": null}], "outputs": [{"localized_name": "CLIP_VISION_OUTPUT", "name": "CLIP_VISION_OUTPUT", "type": "CLIP_VISION_OUTPUT", "links": [22]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "CLIPVisionEncode"}, "widgets_values": ["center"]}, {"id": 6, "type": "Hunyuan3Dv2Conditioning", "pos": [510, -280], "size": [217.82578125, 46], "flags": {}, "order": 3, "mode": 0, "inputs": 
[{"localized_name": "clip_vision_output", "name": "clip_vision_output", "type": "CLIP_VISION_OUTPUT", "link": 22}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [5]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [6]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "Hunyuan3Dv2Conditioning"}, "widgets_values": []}, {"id": 4, "type": "EmptyLatentHunyuan3Dv2", "pos": [450, -180], "size": [270, 82], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "resolution", "name": "resolution", "type": "INT", "widget": {"name": "resolution"}, "link": null}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [7]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "EmptyLatentHunyuan3Dv2"}, "widgets_values": [4096, 1]}, {"id": 9, "type": "VoxelToMesh", "pos": [760, -40], "size": [270, 82], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "voxel", "name": "voxel", "type": "VOXEL", "link": 10}, {"localized_name": "algorithm", "name": "algorithm", "type": "COMBO", "widget": {"name": "algorithm"}, "link": null}, {"localized_name": "threshold", "name": "threshold", "type": "FLOAT", "widget": {"name": "threshold"}, "link": null}], "outputs": [{"localized_name": "MESH", "name": "MESH", "type": "MESH", "links": [11]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "VoxelToMesh"}, "widgets_values": ["surface net", 0.6]}, {"id": 8, "type": "VAEDecodeHunyuan3D", "pos": [760, -200], "size": [270, 102], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 8}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 18}, {"localized_name": "num_chunks", "name": "num_chunks", "type": "INT", "widget": {"name": "num_chunks"}, "link": null}, {"localized_name": "octree_resolution", "name": "octree_resolution", "type": "INT", "widget": {"name": "octree_resolution"}, "link": null}], "outputs": [{"localized_name": "VOXEL", "name": "VOXEL", "type": "VOXEL", "links": [10]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "VAEDecodeHunyuan3D"}, "widgets_values": [8000, 256]}, {"id": 1, "type": "ImageOnlyCheckpointLoader", "pos": [60, -510], "size": [356.0005859375, 100], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 23}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [16]}, {"localized_name": "CLIP_VISION", "name": "CLIP_VISION", "type": "CLIP_VISION", "links": [20]}, {"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [18]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ImageOnlyCheckpointLoader", "models": [{"name": "hunyuan_3d_v2.1.safetensors", "url": "https://huggingface.co/Comfy-Org/hunyuan3D_2.1_repackaged/resolve/main/hunyuan_3d_v2.1.safetensors", "directory": "checkpoints"}]}, "widgets_values": ["hunyuan_3d_v2.1.safetensors"]}, {"id": 3, "type": "ModelSamplingAuraFlow", "pos": [450, -510], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 16}, {"localized_name": "shift", "name": 
"shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [19]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.59", "Node name for S&R": "ModelSamplingAuraFlow"}, "widgets_values": [1]}], "groups": [], "links": [{"id": 16, "origin_id": 1, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 19, "origin_id": 3, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "MODEL"}, {"id": 5, "origin_id": 6, "origin_slot": 0, "target_id": 7, "target_slot": 1, "type": "CONDITIONING"}, {"id": 6, "origin_id": 6, "origin_slot": 1, "target_id": 7, "target_slot": 2, "type": "CONDITIONING"}, {"id": 7, "origin_id": 4, "origin_slot": 0, "target_id": 7, "target_slot": 3, "type": "LATENT"}, {"id": 8, "origin_id": 7, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 18, "origin_id": 1, "origin_slot": 2, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 10, "origin_id": 8, "origin_slot": 0, "target_id": 9, "target_slot": 0, "type": "VOXEL"}, {"id": 20, "origin_id": 1, "origin_slot": 1, "target_id": 13, "target_slot": 0, "type": "CLIP_VISION"}, {"id": 22, "origin_id": 13, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "CLIP_VISION_OUTPUT"}, {"id": 21, "origin_id": -10, "origin_slot": 0, "target_id": 13, "target_slot": 1, "type": "IMAGE"}, {"id": 11, "origin_id": 9, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "MESH"}, {"id": 23, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "3D/Image to 3D Model"}]}, "config": {}, "extra": {"ds": {"scale": 0.620921323059155, "offset": [1636.2881100217016, 965.23503257945]}, "workflowRendererVersion": "LG"}, "version": 0.4} diff --git a/blueprints/Image to Video (Wan 2.2).json b/blueprints/Image to Video (Wan 2.2).json new file mode 100644 index 000000000..cd0b44a72 --- /dev/null +++ b/blueprints/Image to Video (Wan 2.2).json @@ -0,0 +1 @@ +{"id": "ec7da562-7e21-4dac-a0d2-f4441e1efd3b", "revision": 0, "last_node_id": 119, "last_link_id": 231, "nodes": [{"id": 116, "type": "296b573f-1e7d-43df-a2df-925fe5e17063", "pos": [1098.3332694531493, -268.3334707134305], "size": [400, 470], "flags": {"collapsed": false}, "order": 0, "mode": 0, "inputs": [{"label": "start image", "localized_name": "start_image", "name": "start_image", "type": "IMAGE", "link": null}, {"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "width", "type": "INT", "widget": {"name": "width"}, "link": null}, {"name": "height", "type": "INT", "widget": {"name": "height"}, "link": null}, {"name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}, {"label": "low_noise_unet", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"label": "low_noise_lora", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"label": "high_noise_unet", "name": "unet_name_1", "type": "COMBO", "widget": {"name": "unet_name_1"}, "link": null}, {"label": "high_noise_lora", "name": "lora_name_1", "type": "COMBO", "widget": {"name": "lora_name_1"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"name": "VIDEO", "type": "VIDEO", "links": null}], "properties": 
{"proxyWidgets": [["-1", "text"], ["-1", "width"], ["-1", "height"], ["-1", "length"], ["86", "noise_seed"], ["86", "control_after_generate"], ["-1", "unet_name"], ["-1", "lora_name"], ["-1", "unet_name_1"], ["-1", "lora_name_1"], ["-1", "clip_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", 640, 640, 81, null, null, "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors", "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors", "umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan_2.1_vae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "296b573f-1e7d-43df-a2df-925fe5e17063", "version": 1, "state": {"lastGroupId": 16, "lastNodeId": 119, "lastLinkId": 231, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Image to Video (Wan 2.2)", "inputNode": {"id": -10, "bounding": [-250, 570, 131.435546875, 260]}, "outputNode": {"id": -20, "bounding": [1723.4786916118696, 716.3650158766799, 120, 60]}, "inputs": [{"id": "69d8b033-5601-446e-9634-f5cafbd373e2", "name": "start_image", "type": "IMAGE", "linkIds": [186], "localized_name": "start_image", "label": "start image", "shape": 7, "pos": [-138.564453125, 590]}, {"id": "88ae2af6-63c1-41be-90e8-6359f4d5f133", "name": "text", "type": "STRING", "linkIds": [222], "label": "prompt", "pos": [-138.564453125, 610]}, {"id": "fad9d346-653e-4be5-9e52-38cef6fa59f3", "name": "width", "type": "INT", "linkIds": [223], "pos": [-138.564453125, 630]}, {"id": "a4f34897-8063-4613-a2eb-6c2503167eb1", "name": "height", "type": "INT", "linkIds": [224], "pos": [-138.564453125, 650]}, {"id": "dc4d4472-cff7-41e0-9a4a-d118fcd4a21a", "name": "length", "type": "INT", "linkIds": [225], "pos": [-138.564453125, 670]}, {"id": "f7317e79-4a52-460b-9d71-89ec450dc333", "name": "unet_name", "type": "COMBO", "linkIds": [226], "label": "low_noise_unet", "pos": [-138.564453125, 690]}, {"id": "7a470f86-503a-474f-9571-830c8eb99231", "name": "lora_name", "type": "COMBO", "linkIds": [227], "label": "low_noise_lora", "pos": [-138.564453125, 710]}, {"id": "1d88c531-f68e-41b9-95c5-16f944a55b7d", "name": "unet_name_1", "type": "COMBO", "linkIds": [228], "label": "high_noise_unet", "pos": [-138.564453125, 730]}, {"id": "67a79742-33e5-4c38-89d8-ecb021d067c8", "name": "lora_name_1", "type": "COMBO", "linkIds": [229], "label": "high_noise_lora", "pos": [-138.564453125, 750]}, {"id": "9d184b83-37c6-4891-bbdf-ffcdf5ab2016", "name": "clip_name", "type": "COMBO", "linkIds": [230], "pos": [-138.564453125, 770]}, {"id": "24c568ec-aeb2-4c31-9f87-54ee9099d55f", "name": "vae_name", "type": "COMBO", "linkIds": [231], "pos": [-138.564453125, 790]}], "outputs": [{"id": "994c9c48-5f35-48ed-8c9d-0f2b21990cb6", "name": "VIDEO", "type": "VIDEO", "linkIds": [221], "pos": [1743.4786916118696, 736.3650158766799]}], "widgets": [], "nodes": [{"id": 84, "type": "CLIPLoader", "pos": [59.999957705045404, 29.99977085410412], "size": [346.38020833333337, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 230}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": 
[178, 181]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPLoader", "models": [{"name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", "directory": "text_encoders"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan", "default"]}, {"id": 90, "type": "VAELoader", "pos": [59.999957705045404, 189.9997708925786], "size": [344.7265625, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 231}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [176, 185]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "VAELoader", "models": [{"name": "wan_2.1_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors", "directory": "vae"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["wan_2.1_vae.safetensors"]}, {"id": 95, "type": "UNETLoader", "pos": [49.99996468306838, -230.00013148243067], "size": [346.7447916666667, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 226}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [194]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "UNETLoader", "models": [{"name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "directory": "diffusion_models"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", "default"]}, {"id": 96, "type": "UNETLoader", "pos": [49.99996468306838, -100.00008258817711], "size": [346.7447916666667, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 228}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [196]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "UNETLoader", "models": [{"name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "directory": "diffusion_models"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", "default"]}, {"id": 103, "type": "ModelSamplingSD3", "pos": [739.9998741034308, -100.00008258817711], "size": [210, 58], "flags": 
{"collapsed": false}, "order": 12, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 189}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [192]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "ModelSamplingSD3", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": [5.000000000000001]}, {"id": 93, "type": "CLIPTextEncode", "pos": [439.99997175727736, 89.99984067280784], "size": [510, 88], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 181}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 222}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [183]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPTextEncode", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 89, "type": "CLIPTextEncode", "pos": [439.99997175727736, 289.99986864261126], "size": [510, 88], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 178}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [184]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPTextEncode", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"], "color": "#322", "bgcolor": "#533"}, {"id": 101, "type": "LoraLoaderModelOnly", "pos": [449.99996477925447, -230.00013148243067], "size": [280, 82], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 194}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 227}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [190]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.49", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors", "directory": "loras"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors", 1.0000000000000002]}, {"id": 102, "type": "LoraLoaderModelOnly", "pos": [449.99996477925447, -100.00008258817711], "size": [280, 82], "flags": {}, "order": 10, "mode": 0, "inputs": 
[{"localized_name": "model", "name": "model", "type": "MODEL", "link": 196}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 229}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [189]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.49", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors", "directory": "loras"}], "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors", 1.0000000000000002]}, {"id": 104, "type": "ModelSamplingSD3", "pos": [739.9998741034308, -230.00013148243067], "size": [210, 58], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 190}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [195]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "ModelSamplingSD3", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": [5.000000000000001]}, {"id": 98, "type": "WanImageToVideo", "pos": [530.0000206419123, 529.9999245437435], "size": [342.59114583333337, 210], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 183}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 184}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 185}, {"localized_name": "clip_vision_output", "name": "clip_vision_output", "shape": 7, "type": "CLIP_VISION_OUTPUT", "link": null}, {"localized_name": "start_image", "name": "start_image", "shape": 7, "type": "IMAGE", "link": 186}, {"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 223}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 224}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 225}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "slot_index": 0, "links": [168, 172]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "slot_index": 1, "links": [169, 173]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "slot_index": 2, "links": [174]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "WanImageToVideo", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": [640, 640, 81, 1]}, {"id": 86, "type": "KSamplerAdvanced", "pos": [989.9999230265402, -250.00014544809514], "size": [304.73958333333337, 334], "flags": {}, "order": 14, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 
195}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 172}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 173}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 174}, {"localized_name": "add_noise", "name": "add_noise", "type": "COMBO", "widget": {"name": "add_noise"}, "link": null}, {"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "start_at_step", "name": "start_at_step", "type": "INT", "widget": {"name": "start_at_step"}, "link": null}, {"localized_name": "end_at_step", "name": "end_at_step", "type": "INT", "widget": {"name": "end_at_step"}, "link": null}, {"localized_name": "return_with_leftover_noise", "name": "return_with_leftover_noise", "type": "COMBO", "widget": {"name": "return_with_leftover_noise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [170]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "KSamplerAdvanced", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["enable", 0, "randomize", 4, 1, "euler", "simple", 0, 2, "enable"]}, {"id": 85, "type": "KSamplerAdvanced", "pos": [1336.748028098344, -250.00014544809514], "size": [304.73958333333337, 334], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 192}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 168}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 169}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 170}, {"localized_name": "add_noise", "name": "add_noise", "type": "COMBO", "widget": {"name": "add_noise"}, "link": null}, {"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "start_at_step", "name": "start_at_step", "type": "INT", "widget": {"name": "start_at_step"}, "link": null}, {"localized_name": "end_at_step", "name": "end_at_step", "type": "INT", "widget": {"name": "end_at_step"}, "link": null}, {"localized_name": "return_with_leftover_noise", "name": "return_with_leftover_noise", "type": "COMBO", "widget": {"name": "return_with_leftover_noise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [175]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": 
"KSamplerAdvanced", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["disable", 0, "fixed", 4, 1, "euler", "simple", 2, 4, "disable"]}, {"id": 67, "type": "Note", "pos": [510.0000345979581, 819.9999455547611], "size": [390, 88], "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [], "title": "Video Size", "properties": {"ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["By default, we set the video to a smaller size for users with low VRAM. If you have enough VRAM, you can change the size"], "color": "#222", "bgcolor": "#000"}, {"id": 105, "type": "MarkdownNote", "pos": [-469.9999795985529, 279.9998197772136], "size": [480, 170.65104166666669], "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [], "title": "VRAM Usage", "properties": {"ue_properties": {"version": "7.1", "widget_ue_connectable": {}, "input_ue_unconnectable": {}}}, "widgets_values": ["## GPU:RTX4090D 24GB\n\n| Model | Size |VRAM Usage | 1st Generation | 2nd Generation |\n|---------------------|-------|-----------|---------------|-----------------|\n| fp8_scaled |640*640| 84% | ≈ 536s | ≈ 513s |\n| fp8_scaled + 4steps LoRA | 640*640 | 83% | ≈ 97s | ≈ 71s |"], "color": "#222", "bgcolor": "#000"}, {"id": 66, "type": "MarkdownNote", "pos": [-469.9999795985529, -320.00012452364496], "size": [480, 572.1354166666667], "flags": {}, "order": 6, "mode": 0, "inputs": [], "outputs": [], "title": "Model Links", "properties": {"ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["[Tutorial](https://docs.comfy.org/tutorials/video/wan/wan2_2\n)\n\n**Diffusion Model**\n- [wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors)\n- [wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors)\n\n**LoRA**\n- [wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors)\n- [wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors)\n\n**VAE**\n- [wan_2.1_vae.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors)\n\n**Text Encoder** \n- [umt5_xxl_fp8_e4m3fn_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors)\n\n\nFile save location\n\n```\nComfyUI/\n├───📂 models/\n│ ├───📂 diffusion_models/\n│ │ ├─── wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors\n│ │ └─── wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors\n│ ├───📂 loras/\n│ │ ├─── wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors\n│ │ └─── wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors\n│ ├───📂 text_encoders/\n│ │ └─── umt5_xxl_fp8_e4m3fn_scaled.safetensors \n│ └───📂 vae/\n│ └── wan_2.1_vae.safetensors\n```\n"], "color": "#222", "bgcolor": "#000"}, {"id": 115, "type": "Note", "pos": [29.999978639114225, 
-470.00010361843204], "size": [360, 88], "flags": {}, "order": 7, "mode": 0, "inputs": [], "outputs": [], "title": "About 4 Steps LoRA", "properties": {"ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": ["Using the Wan2.2 Lightning LoRA reduces video dynamics, but it greatly shortens generation time. This blueprint uses the 4-step LoRA workflow by default."], "color": "#222", "bgcolor": "#000"}, {"id": 117, "type": "CreateVideo", "pos": [1030, 650], "size": [270, 78], "flags": {}, "order": 18, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 220}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [221]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [16]}, {"id": 87, "type": "VAEDecode", "pos": [1020, 540], "size": [210, 46], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 175}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 176}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [220]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "VAEDecode", "ue_properties": {"widget_ue_connectable": {}, "version": "7.1", "input_ue_unconnectable": {}}}, "widgets_values": []}], "groups": [{"id": 15, "title": "fp8_scaled + 4steps LoRA", "bounding": [30, -350, 1630, 1120], "color": "#444", "font_size": 24, "flags": {}}, {"id": 11, "title": "Step1 - Load models", "bounding": [40, -310, 371.0310363769531, 571.3974609375], "color": "#444", "font_size": 24, "flags": {}}, {"id": 13, "title": "Step4 - Prompt", "bounding": [430, 20, 530, 420], "color": "#444", "font_size": 24, "flags": {}}, {"id": 14, "title": "Step3 - Video size & length", "bounding": [430, 460, 530, 290], "color": "#444", "font_size": 24, "flags": {}}, {"id": 16, "title": "Lightx2v 4steps LoRA", "bounding": [430, -310, 530, 310], "color": "#444", "font_size": 24, "flags": {}}], "links": [{"id": 189, "origin_id": 102, "origin_slot": 0, "target_id": 103, "target_slot": 0, "type": "MODEL"}, {"id": 181, "origin_id": 84, "origin_slot": 0, "target_id": 93, "target_slot": 0, "type": "CLIP"}, {"id": 178, "origin_id": 84, "origin_slot": 0, "target_id": 89, "target_slot": 0, "type": "CLIP"}, {"id": 194, "origin_id": 95, "origin_slot": 0, "target_id": 101, "target_slot": 0, "type": "MODEL"}, {"id": 196, "origin_id": 96, "origin_slot": 0, "target_id": 102, "target_slot": 0, "type": "MODEL"}, {"id": 190, "origin_id": 101, "origin_slot": 0, "target_id": 104, "target_slot": 0, "type": "MODEL"}, {"id": 183, "origin_id": 93, "origin_slot": 0, "target_id": 98, "target_slot": 0, "type": "CONDITIONING"}, {"id": 184, "origin_id": 89, "origin_slot": 0, "target_id": 98, "target_slot": 1, "type": "CONDITIONING"}, {"id": 185, "origin_id": 90, "origin_slot": 0, "target_id": 98, "target_slot": 2, "type": "VAE"}, {"id": 175, "origin_id": 85, "origin_slot": 0, "target_id": 87, "target_slot": 0, "type": "LATENT"}, {"id": 176, "origin_id": 90, "origin_slot": 0, "target_id": 87, "target_slot": 1, "type": "VAE"}, {"id": 195, "origin_id": 104, "origin_slot": 0, "target_id": 86,
"target_slot": 0, "type": "MODEL"}, {"id": 172, "origin_id": 98, "origin_slot": 0, "target_id": 86, "target_slot": 1, "type": "CONDITIONING"}, {"id": 173, "origin_id": 98, "origin_slot": 1, "target_id": 86, "target_slot": 2, "type": "CONDITIONING"}, {"id": 174, "origin_id": 98, "origin_slot": 2, "target_id": 86, "target_slot": 3, "type": "LATENT"}, {"id": 192, "origin_id": 103, "origin_slot": 0, "target_id": 85, "target_slot": 0, "type": "MODEL"}, {"id": 168, "origin_id": 98, "origin_slot": 0, "target_id": 85, "target_slot": 1, "type": "CONDITIONING"}, {"id": 169, "origin_id": 98, "origin_slot": 1, "target_id": 85, "target_slot": 2, "type": "CONDITIONING"}, {"id": 170, "origin_id": 86, "origin_slot": 0, "target_id": 85, "target_slot": 3, "type": "LATENT"}, {"id": 186, "origin_id": -10, "origin_slot": 0, "target_id": 98, "target_slot": 4, "type": "IMAGE"}, {"id": 220, "origin_id": 87, "origin_slot": 0, "target_id": 117, "target_slot": 0, "type": "IMAGE"}, {"id": 221, "origin_id": 117, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 222, "origin_id": -10, "origin_slot": 1, "target_id": 93, "target_slot": 1, "type": "STRING"}, {"id": 223, "origin_id": -10, "origin_slot": 2, "target_id": 98, "target_slot": 5, "type": "INT"}, {"id": 224, "origin_id": -10, "origin_slot": 3, "target_id": 98, "target_slot": 6, "type": "INT"}, {"id": 225, "origin_id": -10, "origin_slot": 4, "target_id": 98, "target_slot": 7, "type": "INT"}, {"id": 226, "origin_id": -10, "origin_slot": 5, "target_id": 95, "target_slot": 0, "type": "COMBO"}, {"id": 227, "origin_id": -10, "origin_slot": 6, "target_id": 101, "target_slot": 1, "type": "COMBO"}, {"id": 228, "origin_id": -10, "origin_slot": 7, "target_id": 96, "target_slot": 0, "type": "COMBO"}, {"id": 229, "origin_id": -10, "origin_slot": 8, "target_id": 102, "target_slot": 1, "type": "COMBO"}, {"id": 230, "origin_id": -10, "origin_slot": 9, "target_id": 84, "target_slot": 0, "type": "COMBO"}, {"id": 231, "origin_id": -10, "origin_slot": 10, "target_id": 90, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Image to video"}]}, "config": {}, "extra": {"ds": {"scale": 0.7926047855889957, "offset": [-30.12529469925767, 690.3829855122884]}, "frontendVersion": "1.37.11", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true, "ue_links": []}, "version": 0.4} diff --git a/blueprints/Pose to Image (Z-Image-Turbo).json b/blueprints/Pose to Image (Z-Image-Turbo).json new file mode 100644 index 000000000..f4c224249 --- /dev/null +++ b/blueprints/Pose to Image (Z-Image-Turbo).json @@ -0,0 +1 @@ +{"id": "e046dd74-e2a7-4f31-a75b-5e11a8c72d4e", "revision": 0, "last_node_id": 26, "last_link_id": 46, "nodes": [{"id": 13, "type": "d8492a46-9e6c-4917-b5ea-4273aabf5f51", "pos": [400, 3630], "size": [400, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "image", "name": "image", "type": "IMAGE", "link": null}, {"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}, {"name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": null}], "outputs": [{"name": "IMAGE", "type": "IMAGE", 
"links": null}], "properties": {"proxyWidgets": [["-1", "text"], ["19", "seed"], ["19", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"], ["-1", "name"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", null, null, "z_image_turbo_bf16.safetensors", "qwen_3_4b.safetensors", "ae.safetensors", "Z-Image-Turbo-Fun-Controlnet-Union.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "d8492a46-9e6c-4917-b5ea-4273aabf5f51", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 26, "lastLinkId": 46, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Pose to Image (Z-Image-Turbo)", "inputNode": {"id": -10, "bounding": [27.60368520069494, 4936.043696127976, 120, 160]}, "outputNode": {"id": -20, "bounding": [1598.6038576146689, 4936.043696127976, 120, 60]}, "inputs": [{"id": "29ca271b-8f63-4e7b-a4b8-c9b4192ada0b", "name": "image", "type": "IMAGE", "linkIds": [41, 42], "label": "image", "pos": [127.60368520069494, 4956.043696127976]}, {"id": "b6549f90-39ee-4b79-9e00-af4d9df969fe", "name": "text", "type": "STRING", "linkIds": [37], "label": "prompt", "pos": [127.60368520069494, 4976.043696127976]}, {"id": "9f23df20-75de-4782-8ff7-225bc7976bbe", "name": "unet_name", "type": "COMBO", "linkIds": [43], "pos": [127.60368520069494, 4996.043696127976]}, {"id": "fc8aa3eb-a537-4976-8b5f-666f0dc5af4b", "name": "clip_name", "type": "COMBO", "linkIds": [44], "pos": [127.60368520069494, 5016.043696127976]}, {"id": "ed2c5269-91ac-4f93-b68d-6b546cef20d8", "name": "vae_name", "type": "COMBO", "linkIds": [45], "pos": [127.60368520069494, 5036.043696127976]}, {"id": "560ba519-ec0c-4ca4-b8f0-f02174012475", "name": "name", "type": "COMBO", "linkIds": [46], "pos": [127.60368520069494, 5056.043696127976]}], "outputs": [{"id": "47f9a22d-6619-4917-9447-a7d5d08dceb5", "name": "IMAGE", "type": "IMAGE", "linkIds": [35], "pos": [1618.6038576146689, 4956.043696127976]}], "widgets": [], "nodes": [{"id": 14, "type": "CLIPLoader", "pos": [340, 4820], "size": [269.9609375, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 44}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [33]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_3_4b.safetensors", "lumina2", "default"]}, {"id": 15, "type": "UNETLoader", "pos": [340, 4670], "size": [269.9609375, 82], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 43}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [28]}], "properties": {"cnr_id": "comfy-core", 
"ver": "0.3.73", "Node name for S&R": "UNETLoader", "models": [{"name": "z_image_turbo_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["z_image_turbo_bf16.safetensors", "default"]}, {"id": 16, "type": "VAELoader", "pos": [340, 5000], "size": [269.9609375, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 45}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [21, 30]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "VAELoader", "models": [{"name": "ae.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ae.safetensors"]}, {"id": 17, "type": "ModelPatchLoader", "pos": [340, 5130], "size": [269.9609375, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "name", "name": "name", "type": "COMBO", "widget": {"name": "name"}, "link": 46}], "outputs": [{"localized_name": "MODEL_PATCH", "name": "MODEL_PATCH", "type": "MODEL_PATCH", "links": [29]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.51", "Node name for S&R": "ModelPatchLoader", "models": [{"name": "Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "url": "https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union/resolve/main/Z-Image-Turbo-Fun-Controlnet-Union.safetensors", "directory": "model_patches"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["Z-Image-Turbo-Fun-Controlnet-Union.safetensors"]}, {"id": 18, "type": "ModelSamplingAuraFlow", "pos": [1110, 4610], "size": [289.97395833333337, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 22}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [23]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 19, "type": "KSampler", "pos": [1110, 4720], "size": [300, 309.9609375], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 23}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 24}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 25}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 26}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": 
null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [20]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize", 9, 1, "res_multistep", "simple", 1]}, {"id": 20, "type": "ConditioningZeroOut", "pos": [860, 5160], "size": [204.134765625, 26], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 27}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [25]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "ConditioningZeroOut", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 21, "type": "QwenImageDiffsynthControlnet", "pos": [720, 5320], "size": [289.97395833333337, 138], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 28}, {"localized_name": "model_patch", "name": "model_patch", "type": "MODEL_PATCH", "link": 29}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 30}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 42}, {"localized_name": "mask", "name": "mask", "shape": 7, "type": "MASK", "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [22]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "QwenImageDiffsynthControlnet", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 23, "type": "CLIPTextEncode", "pos": [660, 4660], "size": [400, 179.9609375], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 33}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 37}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [24, 27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 24, "type": "VAEDecode", "pos": [1450, 4620], "size": [200, 46], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 20}, {"localized_name": "vae", "name": "vae", "type": "VAE", 
"link": 21}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [35]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 25, "type": "GetImageSize", "pos": [330, 5540], "size": [140, 66], "flags": {"collapsed": false}, "order": 11, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 41}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [31]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [32]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.76", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 22, "type": "EmptySD3LatentImage", "pos": [1110, 5540], "size": [259.9609375, 106], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 31}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 32}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [26]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "EmptySD3LatentImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1024, 1024, 1]}], "groups": [{"id": 1, "title": "Prompt", "bounding": [640, 4590, 440, 630], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Models", "bounding": [320, 4590, 300, 640], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Apple ControlNet", "bounding": [640, 5240, 440, 260], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 20, "origin_id": 19, "origin_slot": 0, "target_id": 24, "target_slot": 0, "type": "LATENT"}, {"id": 21, "origin_id": 16, "origin_slot": 0, "target_id": 24, "target_slot": 1, "type": "VAE"}, {"id": 22, "origin_id": 21, "origin_slot": 0, "target_id": 18, "target_slot": 0, "type": "MODEL"}, {"id": 23, "origin_id": 18, "origin_slot": 0, "target_id": 19, "target_slot": 0, "type": "MODEL"}, {"id": 24, "origin_id": 23, "origin_slot": 0, "target_id": 19, "target_slot": 1, "type": "CONDITIONING"}, {"id": 25, "origin_id": 20, "origin_slot": 0, "target_id": 19, "target_slot": 2, "type": "CONDITIONING"}, {"id": 26, "origin_id": 22, "origin_slot": 0, "target_id": 19, "target_slot": 3, "type": "LATENT"}, {"id": 27, "origin_id": 23, "origin_slot": 0, "target_id": 20, "target_slot": 0, "type": "CONDITIONING"}, {"id": 28, "origin_id": 15, "origin_slot": 0, "target_id": 21, "target_slot": 0, "type": "MODEL"}, {"id": 29, "origin_id": 17, "origin_slot": 0, "target_id": 21, "target_slot": 1, "type": "MODEL_PATCH"}, {"id": 30, "origin_id": 16, "origin_slot": 0, "target_id": 21, "target_slot": 2, "type": "VAE"}, {"id": 31, "origin_id": 25, "origin_slot": 0, "target_id": 22, 
"target_slot": 0, "type": "INT"}, {"id": 32, "origin_id": 25, "origin_slot": 1, "target_id": 22, "target_slot": 1, "type": "INT"}, {"id": 33, "origin_id": 14, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "CLIP"}, {"id": 35, "origin_id": 24, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 37, "origin_id": -10, "origin_slot": 1, "target_id": 23, "target_slot": 1, "type": "STRING"}, {"id": 41, "origin_id": -10, "origin_slot": 0, "target_id": 25, "target_slot": 0, "type": "IMAGE"}, {"id": 42, "origin_id": -10, "origin_slot": 0, "target_id": 21, "target_slot": 3, "type": "IMAGE"}, {"id": 43, "origin_id": -10, "origin_slot": 2, "target_id": 15, "target_slot": 0, "type": "COMBO"}, {"id": 44, "origin_id": -10, "origin_slot": 3, "target_id": 14, "target_slot": 0, "type": "COMBO"}, {"id": 45, "origin_id": -10, "origin_slot": 4, "target_id": 16, "target_slot": 0, "type": "COMBO"}, {"id": 46, "origin_id": -10, "origin_slot": 5, "target_id": 17, "target_slot": 0, "type": "COMBO"}], "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "category": "Image generation and editing/Pose to image"}]}, "config": {}, "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true, "ds": {"scale": 0.6479518372239997, "offset": [852.9773200429215, -3036.34291480022]}}, "version": 0.4} diff --git a/blueprints/Pose to Video (LTX 2.0).json b/blueprints/Pose to Video (LTX 2.0).json new file mode 100644 index 000000000..78c098798 --- /dev/null +++ b/blueprints/Pose to Video (LTX 2.0).json @@ -0,0 +1 @@ +{"id": "01cd475b-52df-43bf-aafa-484a5976d2d2", "revision": 0, "last_node_id": 160, "last_link_id": 410, "nodes": [{"id": 1, "type": "f0e58a6b-7246-4103-9fec-73b423634b1f", "pos": [210, 3830], "size": [420, 500], "flags": {"collapsed": false}, "order": 0, "mode": 0, "inputs": [{"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"label": "first_frame_strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}, {"label": "disable_first_frame", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": null}, {"label": "first frame", "name": "image", "type": "IMAGE", "link": null}, {"label": "control image", "name": "input", "type": "IMAGE,MASK", "link": null}, {"name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": null}, {"name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"label": "distll_lora", "name": "lora_name_1", "type": "COMBO", "widget": {"name": "lora_name_1"}, "link": null}, {"label": "upscale_model", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}, {"name": "resize_type.width", "type": "INT", "widget": {"name": "resize_type.width"}, "link": null}, {"name": "resize_type.height", "type": "INT", "widget": {"name": "resize_type.height"}, "link": null}, {"name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "resize_type.width"], ["-1", "resize_type.height"], ["-1", "length"], ["-1", "strength"], ["-1", "bypass"], ["126", "noise_seed"], ["126", 
"control_after_generate"], ["-1", "ckpt_name"], ["-1", "lora_name"], ["-1", "model_name"], ["-1", "lora_name_1"]], "cnr_id": "comfy-core", "ver": "0.7.0", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", 1280, 720, 97, 1, false, null, null, "ltx-2-19b-dev-fp8.safetensors", "ltx-2-19b-ic-lora-pose-control.safetensors", "ltx-2-spatial-upscaler-x2-1.0.safetensors", "ltx-2-19b-distilled-lora-384.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "f0e58a6b-7246-4103-9fec-73b423634b1f", "version": 1, "state": {"lastGroupId": 11, "lastNodeId": 160, "lastLinkId": 410, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Pose to Video (LTX 2.0)", "inputNode": {"id": -10, "bounding": [-2220, 4180, 153.3203125, 280]}, "outputNode": {"id": -20, "bounding": [1750.2777777777776, 4091.1111111111113, 120, 60]}, "inputs": [{"id": "0f1d2f96-933a-4a7b-8f1a-7b49fc4ade09", "name": "text", "type": "STRING", "linkIds": [345], "label": "prompt", "pos": [-2086.6796875, 4200]}, {"id": "59430efe-1090-4e36-8afe-b21ce7f4268b", "name": "strength", "type": "FLOAT", "linkIds": [370, 371], "label": "first_frame_strength", "pos": [-2086.6796875, 4220]}, {"id": "6145a9b9-68ed-4956-89f7-7a5ebdd5c99e", "name": "bypass", "type": "BOOLEAN", "linkIds": [363, 368], "label": "disable_first_frame", "pos": [-2086.6796875, 4240]}, {"id": "f7aa8c12-bdba-4bbd-84cf-b49cfc32a1dd", "name": "image", "type": "IMAGE", "linkIds": [398, 399], "label": "first frame", "pos": [-2086.6796875, 4260]}, {"id": "da40a4c0-cd19-46c6-8eb3-62d0026fbe85", "name": "input", "type": "IMAGE,MASK", "linkIds": [400], "label": "control image", "pos": [-2086.6796875, 4280]}, {"id": "8005344b-99d6-4829-a619-c4e8ef640eb9", "name": "ckpt_name", "type": "COMBO", "linkIds": [401, 402, 403], "pos": [-2086.6796875, 4300]}, {"id": "25e7c4e8-850c-4f37-bc14-e3f4b5f228c0", "name": "lora_name", "type": "COMBO", "linkIds": [404, 405], "pos": [-2086.6796875, 4320]}, {"id": "f16a18dd-947e-400a-8889-02cf998f760a", "name": "lora_name_1", "type": "COMBO", "linkIds": [406], "label": "distll_lora", "pos": [-2086.6796875, 4340]}, {"id": "1abf156c-4c85-4ee5-8671-62df3177d835", "name": "model_name", "type": "COMBO", "linkIds": [407], "label": "upscale_model", "pos": [-2086.6796875, 4360]}, {"id": "203402cf-4253-4daf-bf78-5def9496e0af", "name": "resize_type.width", "type": "INT", "linkIds": [408], "pos": [-2086.6796875, 4380]}, {"id": "e6d8ac4a-34d4-46c6-bcb2-4e66a696438c", "name": "resize_type.height", "type": "INT", "linkIds": [409], "pos": [-2086.6796875, 4400]}, {"id": "6aa6cf2c-bc4f-4f8b-be62-aa15793375dc", "name": "length", "type": "INT", "linkIds": [410], "pos": [-2086.6796875, 4420]}], "outputs": [{"id": "4e837941-de2d-4df8-8f94-686e24036897", "name": "VIDEO", "type": "VIDEO", "linkIds": [304], "localized_name": "VIDEO", "pos": [1770.2777777777776, 4111.111111111111]}], "widgets": [], "nodes": [{"id": 93, "type": "CFGGuider", "pos": [-697.721823660531, 3671.1105325465196], "size": [269.97395833333337, 98], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 326}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 309}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 311}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": 
[{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [261]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 94, "type": "KSamplerSelect", "pos": [-697.721823660531, 3841.1107362825187], "size": [269.97395833333337, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [262]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["euler"]}, {"id": 99, "type": "ManualSigmas", "pos": [410.27824286284044, 3851.110970278795], "size": [269.97395833333337, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "sigmas", "name": "sigmas", "type": "STRING", "widget": {"name": "sigmas"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [278]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "ManualSigmas", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["0.909375, 0.725, 0.421875, 0.0"]}, {"id": 100, "type": "LatentUpscaleModelLoader", "pos": [-69.72208571196083, 3701.1104657166875], "size": [389.97395833333337, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": 407}], "outputs": [{"localized_name": "LATENT_UPSCALE_MODEL", "name": "LATENT_UPSCALE_MODEL", "type": "LATENT_UPSCALE_MODEL", "links": [288]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LatentUpscaleModelLoader", "models": [{"name": "ltx-2-spatial-upscaler-x2-1.0.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-spatial-upscaler-x2-1.0.safetensors", "directory": "latent_upscale_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-spatial-upscaler-x2-1.0.safetensors"]}, {"id": 101, "type": "LTXVConcatAVLatent", "pos": [410.27824286284044, 4101.110949206838], "size": [269.97395833333337, 46], "flags": {}, "order": 18, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 365}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 266}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [279]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 108, "type": "CFGGuider", "pos": [410.27824286284044, 3701.1104657166875], "size": [269.97395833333337, 98], "flags": {}, "order": 22, "mode": 0, "inputs": 
[{"localized_name": "model", "name": "model", "type": "MODEL", "link": 280}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 281}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 282}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}], "outputs": [{"localized_name": "GUIDER", "name": "GUIDER", "type": "GUIDER", "links": [276]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.71", "Node name for S&R": "CFGGuider", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1]}, {"id": 123, "type": "SamplerCustomAdvanced", "pos": [-387.72197839215096, 3521.1103425011374], "size": [213.09895833333334, 106], "flags": {}, "order": 31, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 260}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 261}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 262}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 263}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 323}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": [272]}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 114, "type": "LTXVConditioning", "pos": [-1133.7215420073496, 4141.110347554622], "size": [269.97395833333337, 78], "flags": {}, "order": 27, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 292}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 293}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "FLOAT", "widget": {"name": "frame_rate"}, "link": 355}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [313]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [314]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVConditioning", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 119, "type": "CLIPTextEncode", "pos": [-1163.7218246405453, 3881.1109034489627], "size": [400, 88], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 294}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [293]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["blurry, low quality, still frame, frames, watermark, overlay, titles, has blurbox, has subtitles"], "color": "#323", 
"bgcolor": "#535"}, {"id": 116, "type": "LTXVConcatAVLatent", "pos": [-519.7217122979332, 4701.110031965835], "size": [187.5, 46], "flags": {}, "order": 29, "mode": 0, "inputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "link": 324}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "link": 300}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [322, 323]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVConcatAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 122, "type": "LTXVSeparateAVLatent", "pos": [-393.72183921949465, 3801.1107787938904], "size": [239.97395833333334, 46], "flags": {}, "order": 30, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 272}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": "LATENT", "links": [270]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [266]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 124, "type": "CLIPTextEncode", "pos": [-1174.7214530029996, 3515.1112854387566], "size": [409.97395833333337, 88], "flags": {}, "order": 32, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 295}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 345}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [292]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 98, "type": "KSamplerSelect", "pos": [410.27824286284044, 3981.1101681370833], "size": [269.97395833333337, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}], "outputs": [{"localized_name": "SAMPLER", "name": "SAMPLER", "type": "SAMPLER", "links": [277]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "KSamplerSelect", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["gradient_estimation"]}, {"id": 105, "type": "LoraLoaderModelOnly", "pos": [-69.72208571196083, 3571.110499039739], "size": [389.97395833333337, 82], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 327}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 406}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [280]}], "properties": {"cnr_id": 
"comfy-core", "ver": "0.3.75", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-distilled-lora-384.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-distilled-lora-384.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-distilled-lora-384.safetensors", 1]}, {"id": 95, "type": "LTXVScheduler", "pos": [-699.7218704597861, 3981.1101681370833], "size": [269.97395833333337, 154], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "latent", "name": "latent", "shape": 7, "type": "LATENT", "link": 322}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "max_shift", "name": "max_shift", "type": "FLOAT", "widget": {"name": "max_shift"}, "link": null}, {"localized_name": "base_shift", "name": "base_shift", "type": "FLOAT", "widget": {"name": "base_shift"}, "link": null}, {"localized_name": "stretch", "name": "stretch", "type": "BOOLEAN", "widget": {"name": "stretch"}, "link": null}, {"localized_name": "terminal", "name": "terminal", "type": "FLOAT", "widget": {"name": "terminal"}, "link": null}], "outputs": [{"localized_name": "SIGMAS", "name": "SIGMAS", "type": "SIGMAS", "links": [263]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "LTXVScheduler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [20, 2.05, 0.95, true, 0.1]}, {"id": 126, "type": "RandomNoise", "pos": [-697.721823660531, 3521.1103425011374], "size": [269.97395833333337, 82], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [260]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize"]}, {"id": 107, "type": "SamplerCustomAdvanced", "pos": [710.2782734905775, 3571.110499039739], "size": [212.36979166666669, 106], "flags": {}, "order": 21, "mode": 0, "inputs": [{"localized_name": "noise", "name": "noise", "type": "NOISE", "link": 347}, {"localized_name": "guider", "name": "guider", "type": "GUIDER", "link": 276}, {"localized_name": "sampler", "name": "sampler", "type": "SAMPLER", "link": 277}, {"localized_name": "sigmas", "name": "sigmas", "type": "SIGMAS", "link": 278}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 279}], "outputs": [{"localized_name": "output", "name": "output", "type": "LATENT", "links": []}, {"localized_name": "denoised_output", "name": "denoised_output", "type": "LATENT", "links": [336]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "SamplerCustomAdvanced", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 143, "type": "RandomNoise", "pos": [410.27824286284044, 3571.110499039739], "size": [269.97395833333337, 82], "flags": {}, 
"order": 5, "mode": 0, "inputs": [{"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}], "outputs": [{"localized_name": "NOISE", "name": "NOISE", "type": "NOISE", "links": [347]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "RandomNoise", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "fixed"]}, {"id": 139, "type": "LTXVAudioVAEDecode", "pos": [1130.2783163694094, 3841.1107362825187], "size": [239.97395833333334, 46], "flags": {}, "order": 35, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 338}, {"label": "Audio VAE", "localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 383}], "outputs": [{"localized_name": "Audio", "name": "Audio", "type": "AUDIO", "links": [339]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVAudioVAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 106, "type": "CreateVideo", "pos": [1420.2783925712918, 3761.1104019496292], "size": [269.97395833333337, 78], "flags": {}, "order": 20, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 352}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 339}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 356}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [304]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "CreateVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [25]}, {"id": 134, "type": "LoraLoaderModelOnly", "pos": [-1649.721454901846, 3761.1104019496292], "size": [419.97395833333337, 82], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 325}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 404}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [326, 327]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "ltx-2-19b-ic-lora-pose-control.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2-19b-IC-LoRA-Pose-Control/resolve/main/ltx-2-19b-ic-lora-pose-control.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-ic-lora-pose-control.safetensors", 1], "color": "#322", "bgcolor": "#533"}, {"id": 138, "type": "LTXVSeparateAVLatent", "pos": [730.2784619127078, 3731.1109580277], "size": [193.2916015625, 46], "flags": {}, "order": 34, "mode": 0, "inputs": [{"localized_name": "av_latent", "name": "av_latent", "type": "LATENT", "link": 336}], "outputs": [{"localized_name": "video_latent", "name": "video_latent", "type": 
"LATENT", "links": [337, 351]}, {"localized_name": "audio_latent", "name": "audio_latent", "type": "LATENT", "links": [338]}], "properties": {"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "LTXVSeparateAVLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 144, "type": "VAEDecodeTiled", "pos": [1120.2783619435547, 3641.110599376351], "size": [269.97395833333337, 150], "flags": {}, "order": 36, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 351}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 353}, {"localized_name": "tile_size", "name": "tile_size", "type": "INT", "widget": {"name": "tile_size"}, "link": null}, {"localized_name": "overlap", "name": "overlap", "type": "INT", "widget": {"name": "overlap"}, "link": null}, {"localized_name": "temporal_size", "name": "temporal_size", "type": "INT", "widget": {"name": "temporal_size"}, "link": null}, {"localized_name": "temporal_overlap", "name": "temporal_overlap", "type": "INT", "widget": {"name": "temporal_overlap"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [352]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "VAEDecodeTiled", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [512, 64, 4096, 8]}, {"id": 113, "type": "VAEDecode", "pos": [1130.2783163694094, 3531.1113453160738], "size": [239.97395833333334, 46], "flags": {}, "order": 26, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 337}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 291}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 145, "type": "PrimitiveInt", "pos": [-1600, 4940], "size": [269.97395833333337, 82], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "INT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "INT", "name": "INT", "type": "INT", "links": [354]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveInt", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24, "fixed"]}, {"id": 148, "type": "PrimitiveFloat", "pos": [-1600, 5070], "size": [269.97395833333337, 58], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [355, 356]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "PrimitiveFloat", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [24]}, {"id": 118, "type": "Reroute", "pos": [-229.7217758812614, 
4211.111007032079], "size": [75, 26], "flags": {}, "order": 14, "mode": 0, "inputs": [{"name": "", "type": "*", "link": 303}], "outputs": [{"name": "", "type": "VAE", "links": [289, 291, 367]}], "properties": {"showOutputText": false, "horizontal": false}}, {"id": 151, "type": "LTXVImgToVideoInplace", "pos": [-19.72161465663438, 4071.1107364662485], "size": [269.97395833333337, 122], "flags": {}, "order": 38, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 367}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 398}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 366}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 371}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 368}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [365]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 104, "type": "LTXVCropGuides", "pos": [-9.721939801202097, 3841.1107362825187], "size": [239.97395833333334, 66], "flags": {}, "order": 19, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 310}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 312}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 270}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [281]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [282]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "slot_index": 2, "links": [287]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVCropGuides", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 112, "type": "LTXVLatentUpsampler", "pos": [-9.721939801202097, 3961.111517352274], "size": [259.97395833333337, 66], "flags": {}, "order": 25, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 287}, {"localized_name": "upscale_model", "name": "upscale_model", "type": "LATENT_UPSCALE_MODEL", "link": 288}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 289}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [366]}], "title": "spatial", "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVLatentUpsampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 97, "type": "LTXAVTextEncoderLoader", "pos": [-1649.721454901846, 4041.1110828665023], "size": [419.97395833333337, 106], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "text_encoder", "name": "text_encoder", "type": "COMBO", "widget": {"name": "text_encoder"}, "link": 405}, {"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 403}, {"localized_name": "device", "name": "device", 
"type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [294, 295]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXAVTextEncoderLoader", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}, {"name": "gemma_3_12B_it_fp4_mixed.safetensors", "url": "https://huggingface.co/Comfy-Org/ltx-2/resolve/main/split_files/text_encoders/gemma_3_12B_it_fp4_mixed.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-ic-lora-pose-control.safetensors", "ltx-2-19b-dev-fp8.safetensors", "default"]}, {"id": 103, "type": "CheckpointLoaderSimple", "pos": [-1649.721454901846, 3591.1104777840524], "size": [419.97395833333337, 98], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 401}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [325]}, {"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": []}, {"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [303, 328, 353, 359]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.56", "Node name for S&R": "CheckpointLoaderSimple", "models": [{"name": "ltx-2-19b-dev-fp8.safetensors", "url": "https://huggingface.co/Lightricks/LTX-2/resolve/main/ltx-2-19b-dev-fp8.safetensors", "directory": "checkpoints"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}, {"id": 156, "type": "LTXVAudioVAELoader", "pos": [-1636.9543279290153, 3911.095334870057], "size": [399.0494791666667, 58], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "ckpt_name", "name": "ckpt_name", "type": "COMBO", "widget": {"name": "ckpt_name"}, "link": 402}], "outputs": [{"localized_name": "Audio VAE", "name": "Audio VAE", "type": "VAE", "links": [382, 383]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "LTXVAudioVAELoader"}, "widgets_values": ["ltx-2-19b-dev-fp8.safetensors"]}, {"id": 149, "type": "LTXVImgToVideoInplace", "pos": [-1089.7215608128167, 4401.110560478942], "size": [269.97395833333337, 122], "flags": {}, "order": 37, "mode": 0, "inputs": [{"localized_name": "vae", "name": "vae", "type": "VAE", "link": 359}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 399}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 360}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": 370}, {"localized_name": "bypass", "name": "bypass", "type": "BOOLEAN", "widget": {"name": "bypass"}, "link": 363}], "outputs": [{"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [357]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "LTXVImgToVideoInplace", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1, false]}, {"id": 132, "type": "LTXVAddGuide", "pos": 
[-599.7217670603999, 4421.110609115862], "size": [269.97395833333337, 162], "flags": {}, "order": 33, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 313}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 314}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 328}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "link": 357}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 395}, {"localized_name": "frame_idx", "name": "frame_idx", "type": "INT", "widget": {"name": "frame_idx"}, "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [309, 310]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [311, 312]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [324]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.75", "Node name for S&R": "LTXVAddGuide", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, 1]}, {"id": 154, "type": "MarkdownNote", "pos": [-1630, 5190], "size": [350, 88], "flags": {"collapsed": false}, "order": 11, "mode": 0, "inputs": [], "outputs": [], "title": "Frame Rate Note", "properties": {}, "widgets_values": ["Please make sure the frame rate value is the same in both boxes"], "color": "#432", "bgcolor": "#653"}, {"id": 159, "type": "ResizeImageMaskNode", "pos": [-1610, 4580], "size": [284.375, 154], "flags": {}, "order": 39, "mode": 0, "inputs": [{"localized_name": "input", "name": "input", "type": "IMAGE,MASK", "link": 400}, {"localized_name": "resize_type", "name": "resize_type", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "resize_type"}, "link": null}, {"localized_name": "width", "name": "resize_type.width", "type": "INT", "widget": {"name": "resize_type.width"}, "link": 408}, {"localized_name": "height", "name": "resize_type.height", "type": "INT", "widget": {"name": "resize_type.height"}, "link": 409}, {"localized_name": "crop", "name": "resize_type.crop", "type": "COMBO", "widget": {"name": "resize_type.crop"}, "link": null}, {"localized_name": "scale_method", "name": "scale_method", "type": "COMBO", "widget": {"name": "scale_method"}, "link": null}], "outputs": [{"localized_name": "resized", "name": "resized", "type": "IMAGE,MASK", "links": [391, 392, 395]}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "ResizeImageMaskNode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["scale dimensions", 1280, 720, "center", "lanczos"]}, {"id": 110, "type": "GetImageSize", "pos": [-1600, 4780], "size": [259.97395833333337, 66], "flags": {}, "order": 23, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 391}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": [296]}, {"localized_name": "height", "name": "height", "type": "INT", "links": [297]}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": []}], "properties": {"cnr_id": "comfy-core", "ver": "0.7.0", "Node name for S&R": "GetImageSize", "enableTabs": false, "tabWidth": 
65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 115, "type": "EmptyLTXVLatentVideo", "pos": [-1099.721794809093, 4611.11072170357], "size": [269.97395833333337, 130], "flags": {}, "order": 28, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 296}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 297}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 410}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [360]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.60", "Node name for S&R": "EmptyLTXVLatentVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [768, 512, 97, 1]}, {"id": 111, "type": "LTXVEmptyLatentAudio", "pos": [-1099.721794809093, 4811.110229576288], "size": [269.97395833333337, 106], "flags": {}, "order": 24, "mode": 0, "inputs": [{"localized_name": "audio_vae", "name": "audio_vae", "type": "VAE", "link": 382}, {"localized_name": "frames_number", "name": "frames_number", "type": "INT", "widget": {"name": "frames_number"}, "link": null}, {"localized_name": "frame_rate", "name": "frame_rate", "type": "INT", "widget": {"name": "frame_rate"}, "link": 354}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "Latent", "name": "Latent", "type": "LATENT", "links": [300]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.68", "Node name for S&R": "LTXVEmptyLatentAudio", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [97, 25, 1]}], "groups": [{"id": 1, "title": "Model", "bounding": [-1660, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Basic Sampling", "bounding": [-700, 3440, 570, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Prompt", "bounding": [-1180, 3440, 440, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 5, "title": "Latent", "bounding": [-1180, 4290, 1050, 680], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 9, "title": "Upscale Sampling(2x)", "bounding": [-100, 3440, 1090, 820], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 6, "title": "Sampler", "bounding": [350, 3480, 620, 750], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 7, "title": "Model", "bounding": [-90, 3480, 430, 310], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 11, "title": "Frame rate", "bounding": [-1610, 4860, 290, 271.6], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 326, "origin_id": 134, "origin_slot": 0, "target_id": 93, "target_slot": 0, "type": "MODEL"}, {"id": 309, "origin_id": 132, "origin_slot": 0, "target_id": 93, "target_slot": 1, "type": "CONDITIONING"}, {"id": 311, "origin_id": 132, "origin_slot": 1, "target_id": 93, "target_slot": 2, "type": "CONDITIONING"}, {"id": 266, "origin_id": 122, "origin_slot": 1, "target_id": 101, "target_slot": 1, "type": "LATENT"}, 
{"id": 280, "origin_id": 105, "origin_slot": 0, "target_id": 108, "target_slot": 0, "type": "MODEL"}, {"id": 281, "origin_id": 104, "origin_slot": 0, "target_id": 108, "target_slot": 1, "type": "CONDITIONING"}, {"id": 282, "origin_id": 104, "origin_slot": 1, "target_id": 108, "target_slot": 2, "type": "CONDITIONING"}, {"id": 260, "origin_id": 126, "origin_slot": 0, "target_id": 123, "target_slot": 0, "type": "NOISE"}, {"id": 261, "origin_id": 93, "origin_slot": 0, "target_id": 123, "target_slot": 1, "type": "GUIDER"}, {"id": 262, "origin_id": 94, "origin_slot": 0, "target_id": 123, "target_slot": 2, "type": "SAMPLER"}, {"id": 263, "origin_id": 95, "origin_slot": 0, "target_id": 123, "target_slot": 3, "type": "SIGMAS"}, {"id": 323, "origin_id": 116, "origin_slot": 0, "target_id": 123, "target_slot": 4, "type": "LATENT"}, {"id": 296, "origin_id": 110, "origin_slot": 0, "target_id": 115, "target_slot": 0, "type": "INT"}, {"id": 297, "origin_id": 110, "origin_slot": 1, "target_id": 115, "target_slot": 1, "type": "INT"}, {"id": 325, "origin_id": 103, "origin_slot": 0, "target_id": 134, "target_slot": 0, "type": "MODEL"}, {"id": 292, "origin_id": 124, "origin_slot": 0, "target_id": 114, "target_slot": 0, "type": "CONDITIONING"}, {"id": 293, "origin_id": 119, "origin_slot": 0, "target_id": 114, "target_slot": 1, "type": "CONDITIONING"}, {"id": 294, "origin_id": 97, "origin_slot": 0, "target_id": 119, "target_slot": 0, "type": "CLIP"}, {"id": 324, "origin_id": 132, "origin_slot": 2, "target_id": 116, "target_slot": 0, "type": "LATENT"}, {"id": 300, "origin_id": 111, "origin_slot": 0, "target_id": 116, "target_slot": 1, "type": "LATENT"}, {"id": 313, "origin_id": 114, "origin_slot": 0, "target_id": 132, "target_slot": 0, "type": "CONDITIONING"}, {"id": 314, "origin_id": 114, "origin_slot": 1, "target_id": 132, "target_slot": 1, "type": "CONDITIONING"}, {"id": 328, "origin_id": 103, "origin_slot": 2, "target_id": 132, "target_slot": 2, "type": "VAE"}, {"id": 272, "origin_id": 123, "origin_slot": 0, "target_id": 122, "target_slot": 0, "type": "LATENT"}, {"id": 336, "origin_id": 107, "origin_slot": 1, "target_id": 138, "target_slot": 0, "type": "LATENT"}, {"id": 339, "origin_id": 139, "origin_slot": 0, "target_id": 106, "target_slot": 1, "type": "AUDIO"}, {"id": 295, "origin_id": 97, "origin_slot": 0, "target_id": 124, "target_slot": 0, "type": "CLIP"}, {"id": 303, "origin_id": 103, "origin_slot": 2, "target_id": 118, "target_slot": 0, "type": "VAE"}, {"id": 338, "origin_id": 138, "origin_slot": 1, "target_id": 139, "target_slot": 0, "type": "LATENT"}, {"id": 337, "origin_id": 138, "origin_slot": 0, "target_id": 113, "target_slot": 0, "type": "LATENT"}, {"id": 291, "origin_id": 118, "origin_slot": 0, "target_id": 113, "target_slot": 1, "type": "VAE"}, {"id": 276, "origin_id": 108, "origin_slot": 0, "target_id": 107, "target_slot": 1, "type": "GUIDER"}, {"id": 277, "origin_id": 98, "origin_slot": 0, "target_id": 107, "target_slot": 2, "type": "SAMPLER"}, {"id": 278, "origin_id": 99, "origin_slot": 0, "target_id": 107, "target_slot": 3, "type": "SIGMAS"}, {"id": 279, "origin_id": 101, "origin_slot": 0, "target_id": 107, "target_slot": 4, "type": "LATENT"}, {"id": 327, "origin_id": 134, "origin_slot": 0, "target_id": 105, "target_slot": 0, "type": "MODEL"}, {"id": 310, "origin_id": 132, "origin_slot": 0, "target_id": 104, "target_slot": 0, "type": "CONDITIONING"}, {"id": 312, "origin_id": 132, "origin_slot": 1, "target_id": 104, "target_slot": 1, "type": "CONDITIONING"}, {"id": 270, "origin_id": 122, 
"origin_slot": 0, "target_id": 104, "target_slot": 2, "type": "LATENT"}, {"id": 287, "origin_id": 104, "origin_slot": 2, "target_id": 112, "target_slot": 0, "type": "LATENT"}, {"id": 288, "origin_id": 100, "origin_slot": 0, "target_id": 112, "target_slot": 1, "type": "LATENT_UPSCALE_MODEL"}, {"id": 289, "origin_id": 118, "origin_slot": 0, "target_id": 112, "target_slot": 2, "type": "VAE"}, {"id": 322, "origin_id": 116, "origin_slot": 0, "target_id": 95, "target_slot": 0, "type": "LATENT"}, {"id": 304, "origin_id": 106, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 345, "origin_id": -10, "origin_slot": 0, "target_id": 124, "target_slot": 1, "type": "STRING"}, {"id": 347, "origin_id": 143, "origin_slot": 0, "target_id": 107, "target_slot": 0, "type": "NOISE"}, {"id": 351, "origin_id": 138, "origin_slot": 0, "target_id": 144, "target_slot": 0, "type": "LATENT"}, {"id": 352, "origin_id": 144, "origin_slot": 0, "target_id": 106, "target_slot": 0, "type": "IMAGE"}, {"id": 353, "origin_id": 103, "origin_slot": 2, "target_id": 144, "target_slot": 1, "type": "VAE"}, {"id": 354, "origin_id": 145, "origin_slot": 0, "target_id": 111, "target_slot": 2, "type": "INT"}, {"id": 355, "origin_id": 148, "origin_slot": 0, "target_id": 114, "target_slot": 2, "type": "FLOAT"}, {"id": 356, "origin_id": 148, "origin_slot": 0, "target_id": 106, "target_slot": 2, "type": "FLOAT"}, {"id": 357, "origin_id": 149, "origin_slot": 0, "target_id": 132, "target_slot": 3, "type": "LATENT"}, {"id": 359, "origin_id": 103, "origin_slot": 2, "target_id": 149, "target_slot": 0, "type": "VAE"}, {"id": 360, "origin_id": 115, "origin_slot": 0, "target_id": 149, "target_slot": 2, "type": "LATENT"}, {"id": 363, "origin_id": -10, "origin_slot": 2, "target_id": 149, "target_slot": 4, "type": "BOOLEAN"}, {"id": 365, "origin_id": 151, "origin_slot": 0, "target_id": 101, "target_slot": 0, "type": "LATENT"}, {"id": 366, "origin_id": 112, "origin_slot": 0, "target_id": 151, "target_slot": 2, "type": "LATENT"}, {"id": 367, "origin_id": 118, "origin_slot": 0, "target_id": 151, "target_slot": 0, "type": "VAE"}, {"id": 368, "origin_id": -10, "origin_slot": 2, "target_id": 151, "target_slot": 4, "type": "BOOLEAN"}, {"id": 370, "origin_id": -10, "origin_slot": 1, "target_id": 149, "target_slot": 3, "type": "FLOAT"}, {"id": 371, "origin_id": -10, "origin_slot": 1, "target_id": 151, "target_slot": 3, "type": "FLOAT"}, {"id": 382, "origin_id": 156, "origin_slot": 0, "target_id": 111, "target_slot": 0, "type": "VAE"}, {"id": 383, "origin_id": 156, "origin_slot": 0, "target_id": 139, "target_slot": 1, "type": "VAE"}, {"id": 391, "origin_id": 159, "origin_slot": 0, "target_id": 110, "target_slot": 0, "type": "IMAGE"}, {"id": 395, "origin_id": 159, "origin_slot": 0, "target_id": 132, "target_slot": 4, "type": "IMAGE"}, {"id": 398, "origin_id": -10, "origin_slot": 3, "target_id": 151, "target_slot": 1, "type": "IMAGE"}, {"id": 399, "origin_id": -10, "origin_slot": 3, "target_id": 149, "target_slot": 1, "type": "IMAGE"}, {"id": 400, "origin_id": -10, "origin_slot": 4, "target_id": 159, "target_slot": 0, "type": "IMAGE,MASK"}, {"id": 401, "origin_id": -10, "origin_slot": 5, "target_id": 103, "target_slot": 0, "type": "COMBO"}, {"id": 402, "origin_id": -10, "origin_slot": 5, "target_id": 156, "target_slot": 0, "type": "COMBO"}, {"id": 403, "origin_id": -10, "origin_slot": 5, "target_id": 97, "target_slot": 1, "type": "COMBO"}, {"id": 404, "origin_id": -10, "origin_slot": 6, "target_id": 134, "target_slot": 1, "type": 
"COMBO"}, {"id": 405, "origin_id": -10, "origin_slot": 6, "target_id": 97, "target_slot": 0, "type": "COMBO"}, {"id": 406, "origin_id": -10, "origin_slot": 7, "target_id": 105, "target_slot": 1, "type": "COMBO"}, {"id": 407, "origin_id": -10, "origin_slot": 8, "target_id": 100, "target_slot": 0, "type": "COMBO"}, {"id": 408, "origin_id": -10, "origin_slot": 9, "target_id": 159, "target_slot": 2, "type": "INT"}, {"id": 409, "origin_id": -10, "origin_slot": 10, "target_id": 159, "target_slot": 3, "type": "INT"}, {"id": 410, "origin_id": -10, "origin_slot": 11, "target_id": 115, "target_slot": 2, "type": "INT"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Pose to video"}]}, "config": {}, "extra": {"ds": {"scale": 1.3889423076923078, "offset": [217.0560747663551, -3703.3333333333335]}, "frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "version": 0.4} diff --git a/blueprints/Prompt Enhance.json b/blueprints/Prompt Enhance.json new file mode 100644 index 000000000..2612f66db --- /dev/null +++ b/blueprints/Prompt Enhance.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 15, "last_link_id": 0, "nodes": [{"id": 15, "type": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "pos": [-1490, 2040], "size": [400, 260], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"label": "reference images", "name": "images", "type": "IMAGE", "link": null}], "outputs": [{"name": "STRING", "type": "STRING", "links": null}], "title": "Prompt Enhance", "properties": {"proxyWidgets": [["-1", "prompt"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": [""]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "24d8bbfd-39d4-4774-bff0-3de40cc7a471", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 15, "lastLinkId": 14, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Prompt Enhance", "inputNode": {"id": -10, "bounding": [-2170, 2110, 138.876953125, 80]}, "outputNode": {"id": -20, "bounding": [-640, 2110, 120, 60]}, "inputs": [{"id": "aeab7216-00e0-4528-a09b-bba50845c5a6", "name": "prompt", "type": "STRING", "linkIds": [11], "pos": [-2051.123046875, 2130]}, {"id": "7b73fd36-aa31-4771-9066-f6c83879994b", "name": "images", "type": "IMAGE", "linkIds": [14], "label": "reference images", "pos": [-2051.123046875, 2150]}], "outputs": [{"id": "c7b0d930-68a1-48d1-b496-0519e5837064", "name": "STRING", "type": "STRING", "linkIds": [13], "pos": [-620, 2130]}], "widgets": [], "nodes": [{"id": 11, "type": "GeminiNode", "pos": [-1560, 1990], "size": [470, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": 14}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": null}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 11}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": 
"STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.14.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["", "gemini-3-pro-preview", 42, "randomize", "You are an expert in prompt writing.\nBased on the input, rewrite the user's input into a detailed prompt.\nincluding camera settings, lighting, composition, and style.\nReturn the prompt only"], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 11, "origin_id": -10, "origin_slot": 0, "target_id": 11, "target_slot": 4, "type": "STRING"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "STRING"}, {"id": 14, "origin_id": -10, "origin_slot": 1, "target_id": 11, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Prompt enhance"}]}, "extra": {}} diff --git a/blueprints/Sharpen.json b/blueprints/Sharpen.json new file mode 100644 index 000000000..a4accaf59 --- /dev/null +++ b/blueprints/Sharpen.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 25, "last_link_id": 0, "nodes": [{"id": 25, "type": "621ba4e2-22a8-482d-a369-023753198b7b", "pos": [4610, -790], "size": [230, 58], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"label": "IMAGE", "localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Sharpen", "properties": {"proxyWidgets": [["24", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "621ba4e2-22a8-482d-a369-023753198b7b", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 24, "lastLinkId": 36, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Sharpen", "inputNode": {"id": -10, "bounding": [4090, -825, 120, 60]}, "outputNode": {"id": -20, "bounding": [5150, -825, 120, 60]}, "inputs": [{"id": "37011fb7-14b7-4e0e-b1a0-6a02e8da1fd7", "name": "images.image0", "type": "IMAGE", "linkIds": [34], "localized_name": "images.image0", "label": "image", "pos": [4190, -805]}], "outputs": [{"id": "e9182b3f-635c-4cd4-a152-4b4be17ae4b9", "name": "IMAGE0", "type": "IMAGE", "linkIds": [35], "localized_name": "IMAGE0", "label": "IMAGE", "pos": [5170, -805]}], "widgets": [], "nodes": [{"id": 24, "type": "PrimitiveFloat", "pos": [4280, -1240], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "strength", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [36]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [0.5]}, {"id": 23, "type": "GLSLShader", "pos": [4570, -1240], "size": [370, 192], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 34}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 36}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": null}, {"label": 
"u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [35]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input"]}], "groups": [], "links": [{"id": 36, "origin_id": 24, "origin_slot": 0, "target_id": 23, "target_slot": 2, "type": "FLOAT"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 23, "target_slot": 0, "type": "IMAGE"}, {"id": 35, "origin_id": 23, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}} diff --git a/blueprints/Text to Audio (ACE-Step 1.5).json b/blueprints/Text to Audio (ACE-Step 1.5).json new file mode 100644 index 000000000..51e3bbed3 --- /dev/null +++ b/blueprints/Text to Audio (ACE-Step 1.5).json @@ -0,0 +1 @@ +{"id": "67979fed-a490-450a-83f4-c7c0105d450e", "revision": 0, "last_node_id": 110, "last_link_id": 288, "nodes": [{"id": 21, "type": "510f6b52-34ee-40dd-b532-475497dee41b", "pos": [1810, -560], "size": [390, 460], "flags": {}, "order": 0, "mode": 0, "inputs": [{"name": "tags", "type": "STRING", "widget": {"name": "tags"}, "link": null}, {"name": "lyrics", "type": "STRING", "widget": {"name": "lyrics"}, "link": null}, {"name": "timesignature", "type": "COMBO", "widget": {"name": "timesignature"}, "link": null}, {"name": "language", "type": "COMBO", "widget": {"name": "language"}, "link": null}, {"name": "keyscale", "type": "COMBO", "widget": {"name": "keyscale"}, "link": null}, {"name": "generate_audio_codes", "type": "BOOLEAN", "widget": {"name": "generate_audio_codes"}, "link": null}, {"name": "cfg_scale", "type": "FLOAT", "widget": {"name": "cfg_scale"}, "link": null}, {"label": "duration", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name1", "type": "COMBO", "widget": {"name": "clip_name1"}, "link": null}, {"name": "clip_name2", "type": "COMBO", "widget": {"name": 
"clip_name2"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "AUDIO", "name": "AUDIO", "type": "AUDIO", "links": []}], "properties": {"proxyWidgets": [["-1", "tags"], ["-1", "lyrics"], ["-1", "language"], ["-1", "timesignature"], ["-1", "keyscale"], ["-1", "generate_audio_codes"], ["-1", "cfg_scale"], ["102", "value"], ["102", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name1"], ["-1", "clip_name2"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.12.3", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", "", "en", "4", "E minor", true, 2, null, null, "acestep_v1.5_turbo.safetensors", "qwen_0.6b_ace15.safetensors", "qwen_4b_ace15.safetensors", "ace_1.5_vae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "510f6b52-34ee-40dd-b532-475497dee41b", "version": 1, "state": {"lastGroupId": 3, "lastNodeId": 110, "lastLinkId": 288, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Text to Audio (ACE-Step 1.5)", "inputNode": {"id": -10, "bounding": [-660, -560, 167.458984375, 280]}, "outputNode": {"id": -20, "bounding": [1504.8375, -410, 120, 60]}, "inputs": [{"id": "ebc79d17-2e65-4e0f-855a-c9f2466a5fbf", "name": "tags", "type": "STRING", "linkIds": [264], "pos": [-512.541015625, -540]}, {"id": "230afdb4-a647-4fb7-a68c-a2204fd5d570", "name": "lyrics", "type": "STRING", "linkIds": [265], "pos": [-512.541015625, -520]}, {"id": "efdcbb48-231c-4757-b343-4458c011a283", "name": "timesignature", "type": "COMBO", "linkIds": [266], "pos": [-512.541015625, -500]}, {"id": "811579c1-2979-4721-a1e1-7d9352616e7b", "name": "language", "type": "COMBO", "linkIds": [267], "pos": [-512.541015625, -480]}, {"id": "76a68b0d-7a5f-43dc-873d-d78adf32895f", "name": "keyscale", "type": "COMBO", "linkIds": [268], "pos": [-512.541015625, -460]}, {"id": "11bb3297-272d-4c56-873a-2c974581e838", "name": "generate_audio_codes", "type": "BOOLEAN", "linkIds": [269], "pos": [-512.541015625, -440]}, {"id": "e5a30400-a8b0-422a-a0f3-21739727ab03", "name": "cfg_scale", "type": "FLOAT", "linkIds": [270], "pos": [-512.541015625, -420]}, {"id": "91a37ca5-e0d1-42c5-8248-419b850661a0", "name": "value", "type": "FLOAT", "linkIds": [284], "label": "duration", "pos": [-512.541015625, -400]}, {"id": "30f69f59-e916-48ab-9a5d-ae445b8d8a63", "name": "unet_name", "type": "COMBO", "linkIds": [285], "pos": [-512.541015625, -380]}, {"id": "1af0e8df-6fa7-4df2-b1b4-9c356a8f30a6", "name": "clip_name1", "type": "COMBO", "linkIds": [286], "pos": [-512.541015625, -360]}, {"id": "c7195505-9e83-4f87-b8d7-7747d808577d", "name": "clip_name2", "type": "COMBO", "linkIds": [287], "pos": [-512.541015625, -340]}, {"id": "ca4bd68f-e7c1-4d87-9914-cfe15c63b96e", "name": "vae_name", "type": "COMBO", "linkIds": [288], "pos": [-512.541015625, -320]}], "outputs": [{"id": "bfd748f6-f9ac-4588-81fa-41bde07a58fa", "name": "AUDIO", "type": "AUDIO", "linkIds": [263], "localized_name": "AUDIO", "pos": [1524.8375, -390]}], "widgets": [], "nodes": [{"id": 105, "type": "DualCLIPLoader", "pos": [-165, -660], "size": [380, 130], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name1", "name": "clip_name1", "type": "COMBO", "widget": {"name": "clip_name1"}, "link": 286}, {"localized_name": "clip_name2", "name": "clip_name2", "type": "COMBO", "widget": {"name": "clip_name2"}, "link": 287}, 
{"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [261]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "DualCLIPLoader", "models": [{"name": "qwen_0.6b_ace15.safetensors", "url": "https://huggingface.co/Comfy-Org/ace_step_1.5_ComfyUI_files/resolve/main/split_files/text_encoders/qwen_0.6b_ace15.safetensors", "directory": "text_encoders"}, {"name": "qwen_4b_ace15.safetensors", "url": "https://huggingface.co/Comfy-Org/ace_step_1.5_ComfyUI_files/resolve/main/split_files/text_encoders/qwen_4b_ace15.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["qwen_0.6b_ace15.safetensors", "qwen_4b_ace15.safetensors", "ace", "default"]}, {"id": 106, "type": "VAELoader", "pos": [-165, -470], "size": [380, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 288}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [262]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "VAELoader", "models": [{"name": "ace_1.5_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/ace_step_1.5_ComfyUI_files/resolve/main/split_files/vae/ace_1.5_vae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ace_1.5_vae.safetensors"]}, {"id": 98, "type": "EmptyAceStep1.5LatentAudio", "pos": [-150, 10], "size": [314.90390625, 82], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "seconds", "name": "seconds", "type": "FLOAT", "widget": {"name": "seconds"}, "link": 279}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [249]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "EmptyAceStep1.5LatentAudio", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [120, 1]}, {"id": 47, "type": "ConditioningZeroOut", "pos": [670, 50], "size": [204.75, 26], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 255}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [119]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "ConditioningZeroOut", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 3, "type": "KSampler", "pos": [930, -680], "size": [329.39477481889753, 262], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 175}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", 
"link": 254}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 119}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 249}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": 258}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [256]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "fixed", 8, 1, "euler", "simple", 1]}, {"id": 78, "type": "ModelSamplingAuraFlow", "pos": [930, -810], "size": [329.39477481889753, 60], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 260}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [175]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}, {"id": 18, "type": "VAEDecodeAudio", "pos": [1280, -800], "size": [164.8375, 46], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 256}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 262}], "outputs": [{"localized_name": "AUDIO", "name": "AUDIO", "type": "AUDIO", "links": [263]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "VAEDecodeAudio", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 94, "type": "TextEncodeAceStepAudio1.5", "pos": [270, -790], "size": [611.9184354063266, 679.7643386829468], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 261}, {"localized_name": "tags", "name": "tags", "type": "STRING", "widget": {"name": "tags"}, "link": 264}, {"localized_name": "lyrics", "name": "lyrics", "type": "STRING", "widget": {"name": "lyrics"}, "link": 265}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": 257}, {"localized_name": "bpm", "name": "bpm", "type": "INT", "widget": {"name": "bpm"}, "link": null}, {"localized_name": "duration", "name": "duration", "type": "FLOAT", "widget": {"name": "duration"}, "link": 280}, {"localized_name": "timesignature", "name": "timesignature", "type": "COMBO", "widget": {"name": "timesignature"}, "link": 266}, {"localized_name": "language", "name": 
"language", "type": "COMBO", "widget": {"name": "language"}, "link": 267}, {"localized_name": "keyscale", "name": "keyscale", "type": "COMBO", "widget": {"name": "keyscale"}, "link": 268}, {"localized_name": "generate_audio_codes", "name": "generate_audio_codes", "type": "BOOLEAN", "widget": {"name": "generate_audio_codes"}, "link": 269}, {"localized_name": "cfg_scale", "name": "cfg_scale", "type": "FLOAT", "widget": {"name": "cfg_scale"}, "link": 270}, {"localized_name": "temperature", "name": "temperature", "type": "FLOAT", "widget": {"name": "temperature"}, "link": null}, {"localized_name": "top_p", "name": "top_p", "type": "FLOAT", "widget": {"name": "top_p"}, "link": null}, {"localized_name": "top_k", "name": "top_k", "type": "INT", "widget": {"name": "top_k"}, "link": null}, {"localized_name": "min_p", "name": "min_p", "type": "FLOAT", "widget": {"name": "min_p"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [254, 255]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "TextEncodeAceStepAudio1.5", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", "", 0, "fixed", 190, 120, "4", "en", "E minor", true, 2, 0.85, 0.9, 0, 0]}, {"id": 104, "type": "UNETLoader", "pos": [-170, -790], "size": [380, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 285}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [260]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.1", "Node name for S&R": "UNETLoader", "models": [{"name": "acestep_v1.5_turbo.safetensors", "url": "https://huggingface.co/Comfy-Org/ace_step_1.5_ComfyUI_files/resolve/main/split_files/diffusion_models/acestep_v1.5_turbo.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["acestep_v1.5_turbo.safetensors", "default"]}, {"id": 102, "type": "PrimitiveNode", "pos": [-120, -130], "size": [268.39945903485034, 82], "flags": {}, "order": 3, "mode": 0, "inputs": [], "outputs": [{"name": "INT", "type": "INT", "widget": {"name": "seed"}, "links": [257, 258]}], "title": "seed", "properties": {"Run widget replace on values": false}, "widgets_values": [0, "randomize"]}, {"id": 110, "type": "PrimitiveFloat", "pos": [-120, -280], "size": [270, 58], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": 284}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [279, 280]}], "title": "Song Duration", "properties": {"cnr_id": "comfy-core", "ver": "0.12.3", "Node name for S&R": "PrimitiveFloat", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [120]}], "groups": [{"id": 1, "title": "Step 1 - Load Models", "bounding": [-180, -860, 405, 461.6], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Step 2 - Duration", "bounding": [-180, 
-370, 400, 170], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Step3 - Prompt", "bounding": [260, -860, 640, 960], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 255, "origin_id": 94, "origin_slot": 0, "target_id": 47, "target_slot": 0, "type": "CONDITIONING"}, {"id": 175, "origin_id": 78, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 254, "origin_id": 94, "origin_slot": 0, "target_id": 3, "target_slot": 1, "type": "CONDITIONING"}, {"id": 119, "origin_id": 47, "origin_slot": 0, "target_id": 3, "target_slot": 2, "type": "CONDITIONING"}, {"id": 249, "origin_id": 98, "origin_slot": 0, "target_id": 3, "target_slot": 3, "type": "LATENT"}, {"id": 258, "origin_id": 102, "origin_slot": 0, "target_id": 3, "target_slot": 4, "type": "INT"}, {"id": 260, "origin_id": 104, "origin_slot": 0, "target_id": 78, "target_slot": 0, "type": "MODEL"}, {"id": 256, "origin_id": 3, "origin_slot": 0, "target_id": 18, "target_slot": 0, "type": "LATENT"}, {"id": 262, "origin_id": 106, "origin_slot": 0, "target_id": 18, "target_slot": 1, "type": "VAE"}, {"id": 261, "origin_id": 105, "origin_slot": 0, "target_id": 94, "target_slot": 0, "type": "CLIP"}, {"id": 257, "origin_id": 102, "origin_slot": 0, "target_id": 94, "target_slot": 3, "type": "INT"}, {"id": 263, "origin_id": 18, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "AUDIO"}, {"id": 264, "origin_id": -10, "origin_slot": 0, "target_id": 94, "target_slot": 1, "type": "STRING"}, {"id": 265, "origin_id": -10, "origin_slot": 1, "target_id": 94, "target_slot": 2, "type": "STRING"}, {"id": 266, "origin_id": -10, "origin_slot": 2, "target_id": 94, "target_slot": 6, "type": "COMBO"}, {"id": 267, "origin_id": -10, "origin_slot": 3, "target_id": 94, "target_slot": 7, "type": "COMBO"}, {"id": 268, "origin_id": -10, "origin_slot": 4, "target_id": 94, "target_slot": 8, "type": "COMBO"}, {"id": 269, "origin_id": -10, "origin_slot": 5, "target_id": 94, "target_slot": 9, "type": "BOOLEAN"}, {"id": 270, "origin_id": -10, "origin_slot": 6, "target_id": 94, "target_slot": 10, "type": "FLOAT"}, {"id": 279, "origin_id": 110, "origin_slot": 0, "target_id": 98, "target_slot": 0, "type": "FLOAT"}, {"id": 280, "origin_id": 110, "origin_slot": 0, "target_id": 94, "target_slot": 5, "type": "FLOAT"}, {"id": 284, "origin_id": -10, "origin_slot": 7, "target_id": 110, "target_slot": 0, "type": "FLOAT"}, {"id": 285, "origin_id": -10, "origin_slot": 8, "target_id": 104, "target_slot": 0, "type": "COMBO"}, {"id": 286, "origin_id": -10, "origin_slot": 9, "target_id": 105, "target_slot": 0, "type": "COMBO"}, {"id": 287, "origin_id": -10, "origin_slot": 10, "target_id": 105, "target_slot": 1, "type": "COMBO"}, {"id": 288, "origin_id": -10, "origin_slot": 11, "target_id": 106, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Audio/Music generation"}]}, "config": {}, "extra": {"workflowRendererVersion": "LG", "ds": {"scale": 0.9575633843910519, "offset": [-950.8014851321678, 872.1540230582457]}}, "version": 0.4} diff --git a/blueprints/Text to Image (Z-Image-Turbo).json b/blueprints/Text to Image (Z-Image-Turbo).json new file mode 100644 index 000000000..ce25ce1df --- /dev/null +++ b/blueprints/Text to Image (Z-Image-Turbo).json @@ -0,0 +1 @@ +{"id": "1c3eaa76-5cfa-4dc7-8571-97a570324e01", "revision": 0, "last_node_id": 34, "last_link_id": 40, "nodes": [{"id": 5, "type": "dfe9eb32-97c0-43a5-90d5-4fd37768d91b", "pos": [-2.5766491043910378e-05, 1229.999928629805], 
"size": [400, 470], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "width", "type": "INT", "widget": {"name": "width"}, "link": null}, {"name": "height", "type": "INT", "widget": {"name": "height"}, "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": []}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "width"], ["-1", "height"], ["3", "seed"], ["3", "control_after_generate"], ["-1", "unet_name"], ["-1", "clip_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.3.73", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["", 1024, 1024, null, null, "z_image_turbo_bf16.safetensors", "qwen_3_4b.safetensors", "ae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "dfe9eb32-97c0-43a5-90d5-4fd37768d91b", "version": 1, "state": {"lastGroupId": 4, "lastNodeId": 34, "lastLinkId": 40, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Text to Image (Z-Image-Turbo)", "inputNode": {"id": -10, "bounding": [-80, 425, 120, 160]}, "outputNode": {"id": -20, "bounding": [1490, 415, 120, 60]}, "inputs": [{"id": "fb178669-e742-4a53-8a69-7df59834dfd8", "name": "text", "type": "STRING", "linkIds": [34], "label": "prompt", "pos": [20, 445]}, {"id": "dd780b3c-23e9-46ff-8469-156008f42e5a", "name": "width", "type": "INT", "linkIds": [35], "pos": [20, 465]}, {"id": "7b08d546-6bb0-4ef9-82e9-ffae5e1ee6bc", "name": "height", "type": "INT", "linkIds": [36], "pos": [20, 485]}, {"id": "23087d15-8412-4fbd-b71e-9b6d7ef76de1", "name": "unet_name", "type": "COMBO", "linkIds": [38], "pos": [20, 505]}, {"id": "0677f5c3-2a3f-43d4-98ac-a4c56d5efdc0", "name": "clip_name", "type": "COMBO", "linkIds": [39], "pos": [20, 525]}, {"id": "c85c0445-2641-48b1-bbca-95057edf2fcf", "name": "vae_name", "type": "COMBO", "linkIds": [40], "pos": [20, 545]}], "outputs": [{"id": "1fa72a21-ce00-4952-814e-1f2ffbe87d1d", "name": "IMAGE", "type": "IMAGE", "linkIds": [16], "localized_name": "IMAGE", "pos": [1510, 435]}], "widgets": [], "nodes": [{"id": 30, "type": "CLIPLoader", "pos": [109.99997264844609, 329.99999029608756], "size": [269.9869791666667, 106], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 39}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "links": [28]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPLoader", "models": [{"name": "qwen_3_4b.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/text_encoders/qwen_3_4b.safetensors", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, 
"widgets_values": ["qwen_3_4b.safetensors", "lumina2", "default"]}, {"id": 29, "type": "VAELoader", "pos": [109.99997264844609, 479.9999847172637], "size": [269.9869791666667, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 40}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "links": [27]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "VAELoader", "models": [{"name": "ae.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/vae/ae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["ae.safetensors"]}, {"id": 33, "type": "ConditioningZeroOut", "pos": [639.9999103333332, 620.0000271257795], "size": [204.134765625, 26], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "conditioning", "name": "conditioning", "type": "CONDITIONING", "link": 32}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [33]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "ConditioningZeroOut", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 8, "type": "VAEDecode", "pos": [1219.9999088104782, 160.00009184959066], "size": [209.98697916666669, 46], "flags": {}, "order": 5, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 14}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 27}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [16]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": []}, {"id": 28, "type": "UNETLoader", "pos": [109.99997264844609, 200.0000502647102], "size": [269.9869791666667, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 38}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [26]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "UNETLoader", "models": [{"name": "z_image_turbo_bf16.safetensors", "url": "https://huggingface.co/Comfy-Org/z_image_turbo/resolve/main/split_files/diffusion_models/z_image_turbo_bf16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": ["z_image_turbo_bf16.safetensors", "default"]}, {"id": 27, "type": "CLIPTextEncode", "pos": [429.99997828947767, 200.0000502647102], "size": [409.9869791666667, 319.9869791666667], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 28}, {"localized_name": "text", "name": "text", "type": "STRING", 
"widget": {"name": "text"}, "link": 34}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "links": [30, 32]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.73", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [""]}, {"id": 13, "type": "EmptySD3LatentImage", "pos": [109.99997264844609, 629.9999791384399], "size": [259.9869791666667, 106], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 35}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 36}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [17]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "EmptySD3LatentImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [1024, 1024, 1]}, {"id": 3, "type": "KSampler", "pos": [879.9999615530063, 269.9999774911694], "size": [314.9869791666667, 262], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 13}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 30}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 33}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 17}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [14]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [0, "randomize", 4, 1, "res_multistep", "simple", 1]}, {"id": 11, "type": "ModelSamplingAuraFlow", "pos": [879.9999615530063, 160.00009184959066], "size": [309.9869791666667, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 26}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.64", "Node name for S&R": "ModelSamplingAuraFlow", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": 
false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65}, "widgets_values": [3]}], "groups": [{"id": 2, "title": "Image size", "bounding": [100, 560, 290, 200], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Prompt", "bounding": [410, 130, 450, 540], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 4, "title": "Models", "bounding": [100, 130, 290, 413.6], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 32, "origin_id": 27, "origin_slot": 0, "target_id": 33, "target_slot": 0, "type": "CONDITIONING"}, {"id": 26, "origin_id": 28, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "MODEL"}, {"id": 14, "origin_id": 3, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 27, "origin_id": 29, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 13, "origin_id": 11, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 30, "origin_id": 27, "origin_slot": 0, "target_id": 3, "target_slot": 1, "type": "CONDITIONING"}, {"id": 33, "origin_id": 33, "origin_slot": 0, "target_id": 3, "target_slot": 2, "type": "CONDITIONING"}, {"id": 17, "origin_id": 13, "origin_slot": 0, "target_id": 3, "target_slot": 3, "type": "LATENT"}, {"id": 28, "origin_id": 30, "origin_slot": 0, "target_id": 27, "target_slot": 0, "type": "CLIP"}, {"id": 16, "origin_id": 8, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}, {"id": 34, "origin_id": -10, "origin_slot": 0, "target_id": 27, "target_slot": 1, "type": "STRING"}, {"id": 35, "origin_id": -10, "origin_slot": 1, "target_id": 13, "target_slot": 0, "type": "INT"}, {"id": 36, "origin_id": -10, "origin_slot": 2, "target_id": 13, "target_slot": 1, "type": "INT"}, {"id": 38, "origin_id": -10, "origin_slot": 3, "target_id": 28, "target_slot": 0, "type": "COMBO"}, {"id": 39, "origin_id": -10, "origin_slot": 4, "target_id": 30, "target_slot": 0, "type": "COMBO"}, {"id": 40, "origin_id": -10, "origin_slot": 5, "target_id": 29, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image generation and editing/Text to image"}]}, "config": {}, "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true, "ds": {"scale": 0.8401370345180755, "offset": [940.0587067393087, -830.7121087564725]}}, "version": 0.4} diff --git a/blueprints/Text to Video (Wan 2.2).json b/blueprints/Text to Video (Wan 2.2).json new file mode 100644 index 000000000..9f1b69669 --- /dev/null +++ b/blueprints/Text to Video (Wan 2.2).json @@ -0,0 +1 @@ +{"id": "ec7da562-7e21-4dac-a0d2-f4441e1efd3b", "revision": 0, "last_node_id": 116, "last_link_id": 188, "nodes": [{"id": 114, "type": "59b2f9c7-af11-45c8-a22b-871166f816c0", "pos": [900.0000142553818, 629.999938027585], "size": [400, 394.97395833333337], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "prompt", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}, {"name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}, {"name": "width", "type": "INT", "widget": {"name": "width"}, "link": null}, {"name": "height", "type": "INT", "widget": {"name": "height"}, "link": null}], "outputs": [{"name": "VIDEO", "type": "VIDEO", "links": null}], "properties": {"proxyWidgets": [["-1", "text"], ["-1", "length"], ["-1", "width"], ["-1", "height"], ["81", "noise_seed"], ["81", 
"control_after_generate"]], "cnr_id": "comfy-core", "ver": "0.11.0"}, "widgets_values": ["", 81, 640, 640]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "59b2f9c7-af11-45c8-a22b-871166f816c0", "version": 1, "state": {"lastGroupId": 15, "lastNodeId": 114, "lastLinkId": 196, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Text to Video (Wan 2.2)", "inputNode": {"id": -10, "bounding": [-99.66668418897854, 621.3333300391974, 120, 120]}, "outputNode": {"id": -20, "bounding": [1661.9927561248032, 500.2133490758798, 120, 60]}, "inputs": [{"id": "3a15ef44-456f-4a3a-ade7-7a0840166830", "name": "text", "type": "STRING", "linkIds": [189], "label": "prompt", "pos": [0.333315811021464, 641.3333300391974]}, {"id": "ec76f1bf-b130-4dc9-a50c-0b10002725d6", "name": "length", "type": "INT", "linkIds": [190], "pos": [0.333315811021464, 661.3333300391974]}, {"id": "1abb6b00-a8b4-4e72-9d87-53f1fc5d281e", "name": "width", "type": "INT", "linkIds": [191], "pos": [0.333315811021464, 681.3333300391974]}, {"id": "0af36ab5-ee95-4ce5-9ad9-26436319a0d2", "name": "height", "type": "INT", "linkIds": [192], "pos": [0.333315811021464, 701.3333300391974]}], "outputs": [{"id": "6bdfda51-5568-48bf-8985-dbad1e11b3d8", "name": "VIDEO", "type": "VIDEO", "linkIds": [196], "pos": [1681.9927561248032, 520.2133490758798]}], "widgets": [], "nodes": [{"id": 71, "type": "CLIPLoader", "pos": [50.33329119280961, 51.33334121884377], "size": [346.38020833333337, 98], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": [141, 160]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPLoader", "models": [{"name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", "directory": "text_encoders"}]}, "widgets_values": ["umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan", "default"]}, {"id": 73, "type": "VAELoader", "pos": [50.33329119280961, 211.33336855035554], "size": [344.7135416666667, 50], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [158]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "VAELoader", "models": [{"name": "wan_2.1_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors", "directory": "vae"}]}, "widgets_values": ["wan_2.1_vae.safetensors"]}, {"id": 76, "type": "UNETLoader", "pos": [50.33329119280961, -78.66666636275716], "size": [346.7447916666667, 74], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": 
"MODEL", "type": "MODEL", "slot_index": 0, "links": [155]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "UNETLoader", "models": [{"name": "wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", "directory": "diffusion_models"}]}, "widgets_values": ["wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors", "default"]}, {"id": 75, "type": "UNETLoader", "pos": [50.33329119280961, -208.66667394435814], "size": [346.7447916666667, 74], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [153]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "UNETLoader", "models": [{"name": "wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", "directory": "diffusion_models"}]}, "widgets_values": ["wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors", "default"]}, {"id": 83, "type": "LoraLoaderModelOnly", "pos": [450.3332425195698, -198.66662836038148], "size": [279.9869791666667, 74], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 153}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [152]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.49", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors", "directory": "loras"}]}, "widgets_values": ["wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors", 1.0000000000000002]}, {"id": 85, "type": "LoraLoaderModelOnly", "pos": [450.3332425195698, -58.66669219682302], "size": [279.9869791666667, 74], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 155}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [156]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.49", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors", "directory": "loras"}]}, "widgets_values": 
["wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors", 1.0000000000000002]}, {"id": 86, "type": "ModelSamplingSD3", "pos": [740.3332774326827, -58.66669219682302], "size": [210, 50], "flags": {"collapsed": false}, "order": 9, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 156}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [183]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "ModelSamplingSD3"}, "widgets_values": [5.000000000000001]}, {"id": 82, "type": "ModelSamplingSD3", "pos": [740.3332774326827, -198.66662836038148], "size": [210, 50], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 152}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [181]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "ModelSamplingSD3"}, "widgets_values": [5.000000000000001]}, {"id": 81, "type": "KSamplerAdvanced", "pos": [990.3333640139272, -248.66668077723608], "size": [300, 440.98958333333337], "flags": {}, "order": 13, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 181}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 149}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 150}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 151}, {"localized_name": "add_noise", "name": "add_noise", "type": "COMBO", "widget": {"name": "add_noise"}, "link": null}, {"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "start_at_step", "name": "start_at_step", "type": "INT", "widget": {"name": "start_at_step"}, "link": null}, {"localized_name": "end_at_step", "name": "end_at_step", "type": "INT", "widget": {"name": "end_at_step"}, "link": null}, {"localized_name": "return_with_leftover_noise", "name": "return_with_leftover_noise", "type": "COMBO", "widget": {"name": "return_with_leftover_noise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [145]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "KSamplerAdvanced"}, "widgets_values": ["enable", 0, "randomize", 4, 1, "euler", "simple", 0, 2, "enable"]}, {"id": 74, "type": "EmptyHunyuanLatentVideo", "pos": [70.33326535874369, 381.33332446382485], "size": [314.9869791666667, 122], "flags": {}, "order": 11, "mode": 0, "inputs": [{"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 191}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 
192}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": 190}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [151]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "EmptyHunyuanLatentVideo"}, "widgets_values": [640, 640, 81, 1]}, {"id": 78, "type": "KSamplerAdvanced", "pos": [1310.3334186769505, -248.66668077723608], "size": [304.73958333333337, 440.98958333333337], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 183}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 143}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 144}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 145}, {"localized_name": "add_noise", "name": "add_noise", "type": "COMBO", "widget": {"name": "add_noise"}, "link": null}, {"localized_name": "noise_seed", "name": "noise_seed", "type": "INT", "widget": {"name": "noise_seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "start_at_step", "name": "start_at_step", "type": "INT", "widget": {"name": "start_at_step"}, "link": null}, {"localized_name": "end_at_step", "name": "end_at_step", "type": "INT", "widget": {"name": "end_at_step"}, "link": null}, {"localized_name": "return_with_leftover_noise", "name": "return_with_leftover_noise", "type": "COMBO", "widget": {"name": "return_with_leftover_noise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [157]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "KSamplerAdvanced"}, "widgets_values": ["disable", 0, "fixed", 4, 1, "euler", "simple", 2, 4, "disable"]}, {"id": 114, "type": "CreateVideo", "pos": [1320.333347258908, 441.33336396364655], "size": [269.9869791666667, 70], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 195}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [196]}], "properties": {"cnr_id": "comfy-core", "ver": "0.11.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [16]}, {"id": 112, "type": "Note", "pos": [30.33320002485607, -428.6666237736725], "size": [359.9869791666667, 52], "flags": {}, "order": 4, "mode": 0, "inputs": [], "outputs": [], "title": "About 4 Steps LoRA", "properties": {}, "widgets_values": ["Using the Wan2.2 Lighting LoRA will result in the loss of video dynamics, but it will reduce the generation time. 
This template provides two workflows, and you can enable one as needed."], "color": "#222", "bgcolor": "#000"}, {"id": 62, "type": "MarkdownNote", "pos": [-489.666771800538, -278.666700527147], "size": [479.9869791666667, 542.1354166666667], "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [], "title": "Model Links", "properties": {}, "widgets_values": ["[Tutorial](https://docs.comfy.org/tutorials/video/wan/wan2_2\n) | [教程](https://docs.comfy.org/zh-CN/tutorials/video/wan/wan2_2\n)\n\n**Diffusion Model** \n- [wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors)\n- [wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/diffusion_models/wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors)\n\n**LoRA**\n\n- [wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors)\n- [wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/loras/wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors)\n\n**VAE**\n- [wan_2.1_vae.safetensors](https://huggingface.co/Comfy-Org/Wan_2.2_ComfyUI_Repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors)\n\n**Text Encoder** \n- [umt5_xxl_fp8_e4m3fn_scaled.safetensors](https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors)\n\n\nFile save location\n\n```\nComfyUI/\n├───📂 models/\n│ ├───📂 diffusion_models/\n│ │ ├─── wan2.2_t2v_low_noise_14B_fp8_scaled.safetensors\n│ │ └─── wan2.2_t2v_high_noise_14B_fp8_scaled.safetensors\n│ ├───📂 loras/\n│ │ ├───wan2.2_t2v_lightx2v_4steps_lora_v1.1_low_noise.safetensors\n│ │ └───wan2.2_t2v_lightx2v_4steps_lora_v1.1_high_noise.safetensors\n│ ├───📂 text_encoders/\n│ │ └─── umt5_xxl_fp8_e4m3fn_scaled.safetensors \n│ └───📂 vae/\n│ └── wan_2.1_vae.safetensors\n```\n"], "color": "#222", "bgcolor": "#000"}, {"id": 87, "type": "VAEDecode", "pos": [1020.3331497597994, 471.3333837135574], "size": [210, 46], "flags": {"collapsed": false}, "order": 14, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 157}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 158}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [195]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "VAEDecode"}, "widgets_values": []}, {"id": 72, "type": "CLIPTextEncode", "pos": [440.3333139376125, 331.3333305479798], "size": [500, 170], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 141}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [144, 150]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": 
["色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,裸露,NSFW"], "color": "#322", "bgcolor": "#533"}, {"id": 89, "type": "CLIPTextEncode", "pos": [440.3333139376125, 131.33323788258042], "size": [510, 170], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 160}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": 189}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [143, 149]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.45", "Node name for S&R": "CLIPTextEncode"}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}], "groups": [{"id": 13, "title": "Wan2.2 T2V fp8_scaled + 4 steps LoRA", "bounding": [31.999982477688036, -317.00000329413615, 1610, 880], "color": "#444", "font_size": 24, "flags": {}}, {"id": 6, "title": "Step3 Prompt", "bounding": [431.99998247768815, 57.99999670586385, 530, 460], "color": "#444", "font_size": 24, "flags": {}}, {"id": 7, "title": "Lightx2v 4steps LoRA", "bounding": [431.99998247768815, -275.33333662746946, 530, 320], "color": "#444", "font_size": 24, "flags": {}}, {"id": 11, "title": "Step 1 - Load models", "bounding": [40.33331581102152, -275.33333662746946, 366.7470703125, 563.5814208984375], "color": "#444", "font_size": 24, "flags": {}}, {"id": 12, "title": "Step 2 - Video size", "bounding": [40.33331581102152, 299.6666633725306, 370, 230], "color": "#444", "font_size": 24, "flags": {}}], "links": [{"id": 153, "origin_id": 75, "origin_slot": 0, "target_id": 83, "target_slot": 0, "type": "MODEL"}, {"id": 155, "origin_id": 76, "origin_slot": 0, "target_id": 85, "target_slot": 0, "type": "MODEL"}, {"id": 156, "origin_id": 85, "origin_slot": 0, "target_id": 86, "target_slot": 0, "type": "MODEL"}, {"id": 152, "origin_id": 83, "origin_slot": 0, "target_id": 82, "target_slot": 0, "type": "MODEL"}, {"id": 160, "origin_id": 71, "origin_slot": 0, "target_id": 89, "target_slot": 0, "type": "CLIP"}, {"id": 181, "origin_id": 82, "origin_slot": 0, "target_id": 81, "target_slot": 0, "type": "MODEL"}, {"id": 149, "origin_id": 89, "origin_slot": 0, "target_id": 81, "target_slot": 1, "type": "CONDITIONING"}, {"id": 150, "origin_id": 72, "origin_slot": 0, "target_id": 81, "target_slot": 2, "type": "CONDITIONING"}, {"id": 151, "origin_id": 74, "origin_slot": 0, "target_id": 81, "target_slot": 3, "type": "LATENT"}, {"id": 157, "origin_id": 78, "origin_slot": 0, "target_id": 87, "target_slot": 0, "type": "LATENT"}, {"id": 158, "origin_id": 73, "origin_slot": 0, "target_id": 87, "target_slot": 1, "type": "VAE"}, {"id": 141, "origin_id": 71, "origin_slot": 0, "target_id": 72, "target_slot": 0, "type": "CLIP"}, {"id": 183, "origin_id": 86, "origin_slot": 0, "target_id": 78, "target_slot": 0, "type": "MODEL"}, {"id": 143, "origin_id": 89, "origin_slot": 0, "target_id": 78, "target_slot": 1, "type": "CONDITIONING"}, {"id": 144, "origin_id": 72, "origin_slot": 0, "target_id": 78, "target_slot": 2, "type": "CONDITIONING"}, {"id": 145, "origin_id": 81, "origin_slot": 0, "target_id": 78, "target_slot": 3, "type": "LATENT"}, {"id": 189, "origin_id": -10, "origin_slot": 0, "target_id": 89, "target_slot": 1, "type": "STRING"}, {"id": 190, "origin_id": -10, "origin_slot": 1, "target_id": 74, "target_slot": 2, "type": "INT"}, {"id": 191, "origin_id": -10, 
"origin_slot": 2, "target_id": 74, "target_slot": 0, "type": "INT"}, {"id": 192, "origin_id": -10, "origin_slot": 3, "target_id": 74, "target_slot": 1, "type": "INT"}, {"id": 195, "origin_id": 87, "origin_slot": 0, "target_id": 114, "target_slot": 0, "type": "IMAGE"}, {"id": 196, "origin_id": 114, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Text to video"}]}, "config": {}, "extra": {"frontendVersion": "1.37.10", "workflowRendererVersion": "LG", "VHS_latentpreview": false, "VHS_latentpreviewrate": 0, "VHS_MetadataImage": true, "VHS_KeepIntermediate": true}, "version": 0.4} diff --git a/blueprints/Unsharp Mask.json b/blueprints/Unsharp Mask.json new file mode 100644 index 000000000..9363037ef --- /dev/null +++ b/blueprints/Unsharp Mask.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 30, "last_link_id": 0, "nodes": [{"id": 30, "type": "d99ba3f5-8a56-4365-8e45-3f3ea7c572a1", "pos": [4420, -370], "size": [210, 106], "flags": {}, "order": 4, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": []}], "title": "Unsharp Mask", "properties": {"proxyWidgets": [["27", "value"], ["28", "value"], ["29", "value"]]}, "widgets_values": []}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "d99ba3f5-8a56-4365-8e45-3f3ea7c572a1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 29, "lastLinkId": 43, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Unsharp Mask", "inputNode": {"id": -10, "bounding": [3920, -405, 120, 60]}, "outputNode": {"id": -20, "bounding": [4930, -405, 120, 60]}, "inputs": [{"id": "75354555-d2f3-46b9-a3dd-b076dcfca561", "name": "images.image0", "type": "IMAGE", "linkIds": [39], "localized_name": "images.image0", "label": "image0", "pos": [4020, -385]}], "outputs": [{"id": "04368b94-2a96-46ff-8c07-d0ce3235b40d", "name": "IMAGE0", "type": "IMAGE", "linkIds": [40], "localized_name": "IMAGE0", "pos": [4950, -385]}], "widgets": [], "nodes": [{"id": 27, "type": "PrimitiveFloat", "pos": [4100, -540], "size": [270, 58], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "amount", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [41]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 3, "precision": 2, "step": 0.05}, "widgets_values": [1]}, {"id": 28, "type": "PrimitiveFloat", "pos": [4100, -430], "size": [270, 58], "flags": {}, "order": 1, "mode": 0, "inputs": [{"label": "radius", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [42]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 10, "precision": 1, "step": 0.5}, "widgets_values": [3]}, {"id": 29, "type": "PrimitiveFloat", "pos": [4100, -320], "size": [270, 58], "flags": {}, "order": 2, "mode": 0, "inputs": [{"label": "threshold", "localized_name": "value", "name": "value", "type": "FLOAT", "widget": {"name": "value"}, "link": null}], "outputs": [{"localized_name": "FLOAT", "name": "FLOAT", "type": "FLOAT", "links": [43]}], "properties": {"Node name for S&R": "PrimitiveFloat", "min": 0, "max": 1, 
"precision": 2, "step": 0.05}, "widgets_values": [0]}, {"id": 26, "type": "GLSLShader", "pos": [4470, -580], "size": [400, 232], "flags": {}, "order": 3, "mode": 0, "inputs": [{"label": "image0", "localized_name": "images.image0", "name": "images.image0", "type": "IMAGE", "link": 39}, {"label": "image1", "localized_name": "images.image1", "name": "images.image1", "shape": 7, "type": "IMAGE", "link": null}, {"label": "u_float0", "localized_name": "floats.u_float0", "name": "floats.u_float0", "shape": 7, "type": "FLOAT", "link": 41}, {"label": "u_float1", "localized_name": "floats.u_float1", "name": "floats.u_float1", "shape": 7, "type": "FLOAT", "link": 42}, {"label": "u_float2", "localized_name": "floats.u_float2", "name": "floats.u_float2", "shape": 7, "type": "FLOAT", "link": 43}, {"label": "u_float3", "localized_name": "floats.u_float3", "name": "floats.u_float3", "shape": 7, "type": "FLOAT", "link": null}, {"label": "u_int0", "localized_name": "ints.u_int0", "name": "ints.u_int0", "shape": 7, "type": "INT", "link": null}, {"localized_name": "fragment_shader", "name": "fragment_shader", "type": "STRING", "widget": {"name": "fragment_shader"}, "link": null}, {"localized_name": "size_mode", "name": "size_mode", "type": "COMFY_DYNAMICCOMBO_V3", "widget": {"name": "size_mode"}, "link": null}], "outputs": [{"localized_name": "IMAGE0", "name": "IMAGE0", "type": "IMAGE", "links": [40]}, {"localized_name": "IMAGE1", "name": "IMAGE1", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE2", "name": "IMAGE2", "type": "IMAGE", "links": null}, {"localized_name": "IMAGE3", "name": "IMAGE3", "type": "IMAGE", "links": null}], "properties": {"Node name for S&R": "GLSLShader"}, "widgets_values": ["#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", "from_input"]}], "groups": [], "links": [{"id": 41, 
"origin_id": 27, "origin_slot": 0, "target_id": 26, "target_slot": 2, "type": "FLOAT"}, {"id": 42, "origin_id": 28, "origin_slot": 0, "target_id": 26, "target_slot": 3, "type": "FLOAT"}, {"id": 43, "origin_id": 29, "origin_slot": 0, "target_id": 26, "target_slot": 4, "type": "FLOAT"}, {"id": 39, "origin_id": -10, "origin_slot": 0, "target_id": 26, "target_slot": 0, "type": "IMAGE"}, {"id": 40, "origin_id": 26, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "IMAGE"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Image Tools/Sharpen"}]}} diff --git a/blueprints/Video Captioning (Gemini).json b/blueprints/Video Captioning (Gemini).json new file mode 100644 index 000000000..1d72718a1 --- /dev/null +++ b/blueprints/Video Captioning (Gemini).json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 233, "last_link_id": 0, "nodes": [{"id": 233, "type": "dcf32045-0ee4-4efc-9aca-9f26f3a157be", "pos": [0, 1140], "size": [400, 260], "flags": {}, "order": 7, "mode": 0, "inputs": [{"name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": null}, {"name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": null}, {"name": "video", "type": "VIDEO", "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": []}], "title": "Video Captioning(Gemini)", "properties": {"proxyWidgets": [["-1", "prompt"], ["-1", "model"], ["1", "seed"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": ["Describe this video", "gemini-2.5-pro"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "dcf32045-0ee4-4efc-9aca-9f26f3a157be", "version": 1, "state": {"lastGroupId": 1, "lastNodeId": 16, "lastLinkId": 17, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Captioning(Gemini)", "inputNode": {"id": -10, "bounding": [-6870, 2530, 120, 100]}, "outputNode": {"id": -20, "bounding": [-6240, 2530, 120, 60]}, "inputs": [{"id": "d8cbd7eb-636a-4d7b-8ff6-b22f1755e26c", "name": "prompt", "type": "STRING", "linkIds": [15], "pos": [-6770, 2550]}, {"id": "b034e26a-d114-4604-aec2-32783e86aa6b", "name": "model", "type": "COMBO", "linkIds": [16], "pos": [-6770, 2570]}, {"id": "f7363f60-a106-4e06-90af-df5f53355b98", "name": "video", "type": "VIDEO", "linkIds": [17], "pos": [-6770, 2590]}], "outputs": [{"id": "e12c6e80-5210-4328-a581-bc8924c53070", "name": "STRING", "type": "STRING", "linkIds": [6], "localized_name": "STRING", "pos": [-6220, 2550]}], "widgets": [], "nodes": [{"id": 1, "type": "GeminiNode", "pos": [-6690, 2360], "size": [390, 430], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "shape": 7, "type": "IMAGE", "link": null}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": null}, {"localized_name": "video", "name": "video", "shape": 7, "type": "VIDEO", "link": 17}, {"localized_name": "files", "name": "files", "shape": 7, "type": "GEMINI_INPUT_FILES", "link": null}, {"localized_name": "prompt", "name": "prompt", "type": "STRING", "widget": {"name": "prompt"}, "link": 15}, {"localized_name": "model", "name": "model", "type": "COMBO", "widget": {"name": "model"}, "link": 16}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "system_prompt", "name": "system_prompt", "shape": 7, "type": "STRING", "widget": {"name": "system_prompt"}, "link": null}], "outputs": [{"localized_name": "STRING", "name": "STRING", "type": "STRING", "links": [6]}], "properties": 
{"cnr_id": "comfy-core", "ver": "0.5.1", "Node name for S&R": "GeminiNode"}, "widgets_values": ["Describe this video", "gemini-2.5-pro", 511865409297955, "randomize", "- Role: AI Video Analysis and Description Specialist\n- Background: The user requires a prompt that enables AI to analyze videos (including frame sequences, dynamic movements, audio-visual elements) and generate detailed, structured descriptions. These descriptions must be directly usable as video generation prompts to create similar videos, serving core tasks such as video content creation, creative inspiration extraction, and artistic style exploration.\n- Profile: As an AI Video Analysis and Description Specialist, you possess expertise in computer vision, video temporal sequence processing, motion analysis, and multi-modal natural language generation. You excel at interpreting dynamic visual data (frame-by-frame features + continuous motion) and translating it into precise descriptive text that fully guides the creation of new videos with matching style, rhythm, and content.\n- Skills: Proficiency in video frame feature extraction, motion trajectory recognition, temporal rhythm analysis, scene/shot segmentation, color grading detection, camera movement identification (pan/tilt/zoom/dolly), audio-visual element correlation analysis, and descriptive language generation that captures both static visual features and dynamic temporal characteristics. Mastery of artistic elements in video: composition (per frame + dynamic framing), color palette (consistent + transitional), texture (surface details + motion blur), pacing (frame rate, shot duration), and sound style (background music, ambient sound cues).\n- Goals: To analyze the provided video comprehensively, generate a detailed, structured description that captures all key video elements (static visual features + dynamic motion/temporal characteristics + audio-visual style), and ensure this description can directly serve as a high-quality prompt for creating similar videos.\n- Constraints: \n 1. The description must be clear, structured, and specific enough to guide end-to-end video creation (including frame rate, shot duration, camera movement, motion speed, color transitions).\n 2. Avoid ambiguity; focus on the most salient static (per-frame) and dynamic (temporal) features of the video.\n 3. Prioritize video-specific elements: motion trajectory, shot types (close-up/wide shot/etc.), camera movement, frame rate, scene transitions, rhythm/pacing, and temporal color changes.\n 4. The output must only contain the video generation prompt (no extra explanations).\n- OutputFormat: A detailed, hierarchical text description of the video, structured as follows:\n 1. Core Content & Narrative: Brief overview of the video's subject and temporal progression\n 2. Visual Style (Static): Per-frame key elements (objects, colors, composition, lighting, texture)\n 3. Dynamic Elements (Temporal): Motion details (speed, trajectory, direction), camera movement (type, speed, direction), shot duration/frame rate, scene transitions\n 4. Audio-Visual Style: Color grading (consistent/transitional), rhythm/pacing, and implied audio style (if discernible)\n- Workflow:\n 1. Analyze the video to segment shots/scenes, identify frame-by-frame static visual elements (objects, colors, composition) and cross-frame dynamic elements (motion, camera movement, temporal changes).\n 2. Extract video-specific technical features: frame rate, shot duration, scene transition types, motion speed/rhythm.\n 3. 
Generate a structured, detailed description that captures the essence of the video (static + dynamic + temporal characteristics), ensuring specificity and actionability for video generation.\n 4. Refine the description for clarity, conciseness, and alignment with video generation prompt norms (e.g., including frame rate, camera movement terms, motion speed descriptors)."], "color": "#432", "bgcolor": "#653"}], "groups": [], "links": [{"id": 6, "origin_id": 1, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "*"}, {"id": 15, "origin_id": -10, "origin_slot": 0, "target_id": 1, "target_slot": 4, "type": "STRING"}, {"id": 16, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 5, "type": "COMBO"}, {"id": 17, "origin_id": -10, "origin_slot": 2, "target_id": 1, "target_slot": 2, "type": "VIDEO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Text generation/Video Captioning"}]}} diff --git a/blueprints/Video Inpaint(Wan2.1 VACE).json b/blueprints/Video Inpaint(Wan2.1 VACE).json new file mode 100644 index 000000000..a7c6db003 --- /dev/null +++ b/blueprints/Video Inpaint(Wan2.1 VACE).json @@ -0,0 +1 @@ +{"id": "2f429c60-2e03-4117-908b-31e1fab04bba", "revision": 0, "last_node_id": 229, "last_link_id": 366, "nodes": [{"id": 229, "type": "53a657f3-c9eb-40f2-9ebd-1ed77d25ed67", "pos": [-230, 160], "size": [400, 480], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "video mask", "localized_name": "mask", "name": "mask", "type": "MASK", "link": null}, {"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "width", "type": "INT", "widget": {"name": "width"}, "link": null}, {"name": "height", "type": "INT", "widget": {"name": "height"}, "link": null}, {"label": "reference image", "name": "reference_image_1", "type": "IMAGE", "link": null}, {"name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": null}, {"name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": null}, {"name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": null}, {"name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "properties": {"proxyWidgets": [["6", "text"], ["-1", "width"], ["-1", "height"], ["3", "seed"], ["3", "control_after_generate"], ["-1", "unet_name"], ["-1", "lora_name"], ["-1", "clip_name"], ["-1", "vae_name"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": [null, 720, 720, null, null, "wan2.1_vace_14B_fp16.safetensors", "Wan21_CausVid_14B_T2V_lora_rank32.safetensors", "umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan_2.1_vae.safetensors"]}], "links": [], "groups": [], "definitions": {"subgraphs": [{"id": "53a657f3-c9eb-40f2-9ebd-1ed77d25ed67", "version": 1, "state": {"lastGroupId": 25, "lastNodeId": 229, "lastLinkId": 366, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "local-Video Inpaint(Wan2.1 VACE)", "inputNode": {"id": -10, "bounding": [-970, 800, 132.54296875, 220]}, "outputNode": {"id": -20, "bounding": [1480, 535, 120, 60]}, "inputs": [{"id": "9fdda38d-6aa7-48ad-b425-f493d8aa585c", "name": "mask", "type": "MASK", "linkIds": [351, 335, 345], "localized_name": "mask", "label": "video mask", "pos": [-857.45703125, 820]}, {"id": "8b1788cc-46d2-4f40-8b33-70fd56b4cb24", "name": "video", "type": "VIDEO", "linkIds": [336], "localized_name": "video", "pos": [-857.45703125, 840]}, {"id": "09393f21-257e-4476-bb02-54899a8252b8", "name": 
"width", "type": "INT", "linkIds": [355], "pos": [-857.45703125, 860]}, {"id": "07a030f7-7eac-4b3f-b8f3-f00ee87b191d", "name": "height", "type": "INT", "linkIds": [356], "pos": [-857.45703125, 880]}, {"id": "255908d3-6cc9-48fc-b76b-ab9fb72695bc", "name": "reference_image_1", "type": "IMAGE", "linkIds": [361], "label": "reference image", "pos": [-857.45703125, 900]}, {"id": "18a5d241-523c-433d-ae05-25b6e69d1e29", "name": "unet_name", "type": "COMBO", "linkIds": [363], "pos": [-857.45703125, 920]}, {"id": "d7576e1b-da5f-402f-81b2-d37f838b1f8f", "name": "lora_name", "type": "COMBO", "linkIds": [364], "pos": [-857.45703125, 940]}, {"id": "41676a3e-c710-4723-821e-f651ad3784b1", "name": "clip_name", "type": "COMBO", "linkIds": [365], "pos": [-857.45703125, 960]}, {"id": "41fc878c-9aa6-4c12-bef3-ceda6b094b7c", "name": "vae_name", "type": "COMBO", "linkIds": [366], "pos": [-857.45703125, 980]}], "outputs": [{"id": "d4861f39-1011-49dc-80fd-ee318b614a8d", "name": "VIDEO", "type": "VIDEO", "linkIds": [129], "localized_name": "VIDEO", "pos": [1500, 555]}], "widgets": [], "nodes": [{"id": 58, "type": "TrimVideoLatent", "pos": [760, 390], "size": [315, 60], "flags": {"collapsed": false}, "order": 13, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 116}, {"localized_name": "trim_amount", "name": "trim_amount", "type": "INT", "widget": {"name": "trim_amount"}, "link": 115}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "links": [117]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "TrimVideoLatent", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {"trim_amount": true}}, "widgets_values": [0]}, {"id": 8, "type": "VAEDecode", "pos": [770, 500], "size": [315, 46], "flags": {"collapsed": false}, "order": 11, "mode": 0, "inputs": [{"localized_name": "samples", "name": "samples", "type": "LATENT", "link": 117}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 76}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "slot_index": 0, "links": [139]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAEDecode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 48, "type": "ModelSamplingSD3", "pos": [400, 50], "size": [315, 58], "flags": {}, "order": 9, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 279}, {"localized_name": "shift", "name": "shift", "type": "FLOAT", "widget": {"name": "shift"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [280]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "ModelSamplingSD3", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": [5]}, {"id": 219, "type": "InvertMask", "pos": [400, 990], "size": [140, 26], "flags": {}, "order": 24, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 351}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [352]}], "properties": 
{"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "InvertMask"}, "widgets_values": []}, {"id": 216, "type": "MaskToImage", "pos": [560, 990], "size": [193.2779296875, 26], "flags": {}, "order": 23, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 352}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [334]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "MaskToImage"}, "widgets_values": []}, {"id": 213, "type": "RebatchImages", "pos": [410, 690], "size": [230, 60], "flags": {}, "order": 21, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 360}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": 340}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "shape": 6, "type": "IMAGE", "links": [333]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "RebatchImages"}, "widgets_values": [1]}, {"id": 68, "type": "CreateVideo", "pos": [1150, 50], "size": [270, 78], "flags": {"collapsed": false}, "order": 14, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 139}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 362}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 353}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [129]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "CreateVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": [16]}, {"id": 208, "type": "ImageCompositeMasked", "pos": [410, 790], "size": [230, 146], "flags": {}, "order": 18, "mode": 0, "inputs": [{"localized_name": "destination", "name": "destination", "type": "IMAGE", "link": 333}, {"localized_name": "source", "name": "source", "type": "IMAGE", "link": 334}, {"localized_name": "mask", "name": "mask", "shape": 7, "type": "MASK", "link": 335}, {"localized_name": "x", "name": "x", "type": "INT", "widget": {"name": "x"}, "link": null}, {"localized_name": "y", "name": "y", "type": "INT", "widget": {"name": "y"}, "link": null}, {"localized_name": "resize_source", "name": "resize_source", "type": "BOOLEAN", "widget": {"name": "resize_source"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [341, 344]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "ImageCompositeMasked"}, "widgets_values": [0, 0, true]}, {"id": 214, "type": "PreviewImage", "pos": [760, 690], "size": [300, 300], "flags": {}, "order": 22, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 341}], "outputs": [], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "PreviewImage"}, "widgets_values": []}, {"id": 111, "type": "MaskToImage", "pos": [20, 1270], "size": [240, 26], "flags": {}, "order": 15, "mode": 0, "inputs": [{"localized_name": "mask", "name": "mask", "type": "MASK", "link": 345}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [201]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "MaskToImage", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, 
"hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": []}, {"id": 129, "type": "RepeatImageBatch", "pos": [20, 1160], "size": [240, 60], "flags": {}, "order": 16, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 201}, {"localized_name": "amount", "name": "amount", "type": "INT", "widget": {"name": "amount"}, "link": 346}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [202]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "RepeatImageBatch", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {"amount": true}}, "widgets_values": [17]}, {"id": 130, "type": "ImageToMask", "pos": [20, 1050], "size": [240, 60], "flags": {}, "order": 17, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 202}, {"localized_name": "channel", "name": "channel", "type": "COMBO", "widget": {"name": "channel"}, "link": null}], "outputs": [{"localized_name": "MASK", "name": "MASK", "type": "MASK", "links": [349]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "ImageToMask", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["red"]}, {"id": 3, "type": "KSampler", "pos": [770, 50], "size": [315, 262], "flags": {}, "order": 10, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 280}, {"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 98}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 99}, {"localized_name": "latent_image", "name": "latent_image", "type": "LATENT", "link": 160}, {"localized_name": "seed", "name": "seed", "type": "INT", "widget": {"name": "seed"}, "link": null}, {"localized_name": "steps", "name": "steps", "type": "INT", "widget": {"name": "steps"}, "link": null}, {"localized_name": "cfg", "name": "cfg", "type": "FLOAT", "widget": {"name": "cfg"}, "link": null}, {"localized_name": "sampler_name", "name": "sampler_name", "type": "COMBO", "widget": {"name": "sampler_name"}, "link": null}, {"localized_name": "scheduler", "name": "scheduler", "type": "COMBO", "widget": {"name": "scheduler"}, "link": null}, {"localized_name": "denoise", "name": "denoise", "type": "FLOAT", "widget": {"name": "denoise"}, "link": null}], "outputs": [{"localized_name": "LATENT", "name": "LATENT", "type": "LATENT", "slot_index": 0, "links": [116]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "KSampler", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": [584027519362099, "randomize", 4, 1, "uni_pc", "simple", 1]}, {"id": 224, "type": "MarkdownNote", "pos": [420, -160], "size": [310, 110], "flags": {}, "order": 0, "mode": 0, "inputs": [], "outputs": [], "title": "About Video Size", "properties": {}, "widgets_values": ["| Model | 480P | 720P |\n| ------------------------------------------------------------ | ---- | ---- |\n| [VACE-1.3B](https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B) | ✅ | ❌ |\n| 
[VACE-14B](https://huggingface.co/Wan-AI/Wan2.1-VACE-14B) | ✅ | ✅ |"], "color": "#432", "bgcolor": "#000"}, {"id": 223, "type": "MarkdownNote", "pos": [770, -210], "size": [303.90106201171875, 158.5415802001953], "flags": {}, "order": 1, "mode": 0, "inputs": [], "outputs": [], "title": "KSampler Setting", "properties": {}, "widgets_values": ["## Default\n\n- steps:20\n- cfg:6.0\n\n## For CausVid LoRA\n\n- steps: 2-4\n- cfg: 1.0\n\n"], "color": "#432", "bgcolor": "#000"}, {"id": 6, "type": "CLIPTextEncode", "pos": [-80, 60], "size": [420, 280], "flags": {}, "order": 7, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 74}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [96]}], "title": "CLIP Text Encode (Positive Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": [""], "color": "#232", "bgcolor": "#353"}, {"id": 140, "type": "UNETLoader", "pos": [-505.8336486816406, 88.22794342041016], "size": [360, 82], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "unet_name", "name": "unet_name", "type": "COMBO", "widget": {"name": "unet_name"}, "link": 363}, {"localized_name": "weight_dtype", "name": "weight_dtype", "type": "COMBO", "widget": {"name": "weight_dtype"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "slot_index": 0, "links": [248]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "UNETLoader", "models": [{"name": "wan2.1_vace_14B_fp16.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/diffusion_models/wan2.1_vace_14B_fp16.safetensors", "directory": "diffusion_models"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["wan2.1_vace_14B_fp16.safetensors", "fp8_e4m3fn_fast"]}, {"id": 154, "type": "LoraLoaderModelOnly", "pos": [-505.8336486816406, 228.2279510498047], "size": [360, 85.11004638671875], "flags": {}, "order": 6, "mode": 0, "inputs": [{"localized_name": "model", "name": "model", "type": "MODEL", "link": 248}, {"localized_name": "lora_name", "name": "lora_name", "type": "COMBO", "widget": {"name": "lora_name"}, "link": 364}, {"localized_name": "strength_model", "name": "strength_model", "type": "FLOAT", "widget": {"name": "strength_model"}, "link": null}], "outputs": [{"localized_name": "MODEL", "name": "MODEL", "type": "MODEL", "links": [279]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "LoraLoaderModelOnly", "models": [{"name": "Wan21_CausVid_14B_T2V_lora_rank32.safetensors", "url": "https://huggingface.co/Kijai/WanVideo_comfy/resolve/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors", "directory": "loras"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["Wan21_CausVid_14B_T2V_lora_rank32.safetensors", 0.30000000000000004]}, {"id": 
38, "type": "CLIPLoader", "pos": [-499.14141845703125, 368.0911865234375], "size": [360, 106], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "clip_name", "name": "clip_name", "type": "COMBO", "widget": {"name": "clip_name"}, "link": 365}, {"localized_name": "type", "name": "type", "type": "COMBO", "widget": {"name": "type"}, "link": null}, {"localized_name": "device", "name": "device", "shape": 7, "type": "COMBO", "widget": {"name": "device"}, "link": null}], "outputs": [{"localized_name": "CLIP", "name": "CLIP", "type": "CLIP", "slot_index": 0, "links": [74, 75]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "CLIPLoader", "models": [{"name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors?download=true", "directory": "text_encoders"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["umt5_xxl_fp8_e4m3fn_scaled.safetensors", "wan", "default"]}, {"id": 39, "type": "VAELoader", "pos": [-498.5298156738281, 517.2576293945312], "size": [360, 60], "flags": {}, "order": 4, "mode": 0, "inputs": [{"localized_name": "vae_name", "name": "vae_name", "type": "COMBO", "widget": {"name": "vae_name"}, "link": 366}], "outputs": [{"localized_name": "VAE", "name": "VAE", "type": "VAE", "slot_index": 0, "links": [76, 101]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "VAELoader", "models": [{"name": "wan_2.1_vae.safetensors", "url": "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/resolve/main/split_files/vae/wan_2.1_vae.safetensors", "directory": "vae"}], "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["wan_2.1_vae.safetensors"]}, {"id": 221, "type": "MarkdownNote", "pos": [380, 1090], "size": [480, 170], "flags": {}, "order": 5, "mode": 0, "inputs": [], "outputs": [], "title": "[EN] About video mask", "properties": {"widget_ue_connectable": {}}, "widgets_values": ["Currently, it's difficult to perfectly draw dynamic masks for different frames using only core nodes. However, to avoid requiring users to install additional custom nodes, our templates only use core nodes. 
You can refer to this implementation idea to achieve video inpainting.\n\nYou can use KJNode’s Points Editor and Sam2Segmentation to create some dynamic mask functions.\n\nCustom node links:\n- [ComfyUI-KJNodes](https://github.com/kijai/ComfyUI-KJNodes)\n- [ComfyUI-segment-anything-2](https://github.com/kijai/ComfyUI-segment-anything-2)"], "color": "#432", "bgcolor": "#000"}, {"id": 7, "type": "CLIPTextEncode", "pos": [-80, 390], "size": [425.27801513671875, 180.6060791015625], "flags": {}, "order": 8, "mode": 0, "inputs": [{"localized_name": "clip", "name": "clip", "type": "CLIP", "link": 75}, {"localized_name": "text", "name": "text", "type": "STRING", "widget": {"name": "text"}, "link": null}], "outputs": [{"localized_name": "CONDITIONING", "name": "CONDITIONING", "type": "CONDITIONING", "slot_index": 0, "links": [97]}], "title": "CLIP Text Encode (Negative Prompt)", "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "CLIPTextEncode", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {}}, "widgets_values": ["过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"], "color": "#223", "bgcolor": "#335"}, {"id": 229, "type": "ImageFromBatch", "pos": [-510, 800], "size": [270, 82], "flags": {}, "order": 25, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 358}, {"localized_name": "batch_index", "name": "batch_index", "type": "INT", "widget": {"name": "batch_index"}, "link": null}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [359, 360]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "ImageFromBatch"}, "widgets_values": [0, 81]}, {"id": 49, "type": "WanVaceToVideo", "pos": [400, 200], "size": [315, 254], "flags": {}, "order": 12, "mode": 0, "inputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "link": 96}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "link": 97}, {"localized_name": "vae", "name": "vae", "type": "VAE", "link": 101}, {"localized_name": "control_video", "name": "control_video", "shape": 7, "type": "IMAGE", "link": 344}, {"localized_name": "control_masks", "name": "control_masks", "shape": 7, "type": "MASK", "link": 349}, {"localized_name": "reference_image", "name": "reference_image", "shape": 7, "type": "IMAGE", "link": 361}, {"localized_name": "width", "name": "width", "type": "INT", "widget": {"name": "width"}, "link": 355}, {"localized_name": "height", "name": "height", "type": "INT", "widget": {"name": "height"}, "link": 356}, {"localized_name": "length", "name": "length", "type": "INT", "widget": {"name": "length"}, "link": null}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "widget": {"name": "batch_size"}, "link": null}, {"localized_name": "strength", "name": "strength", "type": "FLOAT", "widget": {"name": "strength"}, "link": null}], "outputs": [{"localized_name": "positive", "name": "positive", "type": "CONDITIONING", "links": [98]}, {"localized_name": "negative", "name": "negative", "type": "CONDITIONING", "links": [99]}, {"localized_name": "latent", "name": "latent", "type": "LATENT", "links": [160]}, {"localized_name": 
"trim_latent", "name": "trim_latent", "type": "INT", "links": [115]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.34", "Node name for S&R": "WanVaceToVideo", "enableTabs": false, "tabWidth": 65, "tabXOffset": 10, "hasSecondTab": false, "secondTabText": "Send Back", "secondTabOffset": 80, "secondTabWidth": 65, "widget_ue_connectable": {"width": true, "height": true, "length": true}}, "widgets_values": [720, 720, 81, 1, 1]}, {"id": 211, "type": "GetImageSize", "pos": [70, 800], "size": [190, 66], "flags": {"collapsed": false}, "order": 20, "mode": 0, "inputs": [{"localized_name": "image", "name": "image", "type": "IMAGE", "link": 359}], "outputs": [{"localized_name": "width", "name": "width", "type": "INT", "links": null}, {"localized_name": "height", "name": "height", "type": "INT", "links": null}, {"localized_name": "batch_size", "name": "batch_size", "type": "INT", "links": [340, 346]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "GetImageSize"}, "widgets_values": []}, {"id": 210, "type": "GetVideoComponents", "pos": [-510, 690], "size": [193.530859375, 66], "flags": {}, "order": 19, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 336}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [358]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [362]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [353]}], "properties": {"cnr_id": "comfy-core", "ver": "0.3.40", "Node name for S&R": "GetVideoComponents"}, "widgets_values": []}], "groups": [{"id": 1, "title": "Step1 - Load models here", "bounding": [-540, -30, 430, 620], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 2, "title": "Prompt", "bounding": [-90, -30, 450, 620], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 3, "title": "Sampling & Decoding", "bounding": [380, -30, 720, 620], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 10, "title": "Repeat Mask Batch", "bounding": [-90, 910, 450, 460], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 21, "title": "Get video info", "bounding": [-540, 610, 900, 290], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 22, "title": "Composite video & masks", "bounding": [380, 610, 720, 420], "color": "#3f789e", "font_size": 24, "flags": {}}, {"id": 23, "title": "Step4 - Set video size & length", "bounding": [390, 130, 360, 340], "color": "#A88", "font_size": 24, "flags": {}}, {"id": 25, "title": "14B", "bounding": [-520, 10, 380, 308.7100524902344], "color": "#3f789e", "font_size": 24, "flags": {}}], "links": [{"id": 116, "origin_id": 3, "origin_slot": 0, "target_id": 58, "target_slot": 0, "type": "LATENT"}, {"id": 115, "origin_id": 49, "origin_slot": 3, "target_id": 58, "target_slot": 1, "type": "INT"}, {"id": 117, "origin_id": 58, "origin_slot": 0, "target_id": 8, "target_slot": 0, "type": "LATENT"}, {"id": 76, "origin_id": 39, "origin_slot": 0, "target_id": 8, "target_slot": 1, "type": "VAE"}, {"id": 279, "origin_id": 154, "origin_slot": 0, "target_id": 48, "target_slot": 0, "type": "MODEL"}, {"id": 352, "origin_id": 219, "origin_slot": 0, "target_id": 216, "target_slot": 0, "type": "MASK"}, {"id": 340, "origin_id": 211, "origin_slot": 2, "target_id": 213, "target_slot": 1, "type": "INT"}, {"id": 96, "origin_id": 6, "origin_slot": 0, "target_id": 49, "target_slot": 0, "type": "CONDITIONING"}, {"id": 97, "origin_id": 7, "origin_slot": 0, "target_id": 49, "target_slot": 1, "type": 
"CONDITIONING"}, {"id": 101, "origin_id": 39, "origin_slot": 0, "target_id": 49, "target_slot": 2, "type": "VAE"}, {"id": 344, "origin_id": 208, "origin_slot": 0, "target_id": 49, "target_slot": 3, "type": "IMAGE"}, {"id": 349, "origin_id": 130, "origin_slot": 0, "target_id": 49, "target_slot": 4, "type": "MASK"}, {"id": 139, "origin_id": 8, "origin_slot": 0, "target_id": 68, "target_slot": 0, "type": "IMAGE"}, {"id": 353, "origin_id": 210, "origin_slot": 2, "target_id": 68, "target_slot": 2, "type": "FLOAT"}, {"id": 333, "origin_id": 213, "origin_slot": 0, "target_id": 208, "target_slot": 0, "type": "IMAGE"}, {"id": 334, "origin_id": 216, "origin_slot": 0, "target_id": 208, "target_slot": 1, "type": "IMAGE"}, {"id": 341, "origin_id": 208, "origin_slot": 0, "target_id": 214, "target_slot": 0, "type": "IMAGE"}, {"id": 201, "origin_id": 111, "origin_slot": 0, "target_id": 129, "target_slot": 0, "type": "IMAGE"}, {"id": 346, "origin_id": 211, "origin_slot": 2, "target_id": 129, "target_slot": 1, "type": "INT"}, {"id": 202, "origin_id": 129, "origin_slot": 0, "target_id": 130, "target_slot": 0, "type": "IMAGE"}, {"id": 280, "origin_id": 48, "origin_slot": 0, "target_id": 3, "target_slot": 0, "type": "MODEL"}, {"id": 98, "origin_id": 49, "origin_slot": 0, "target_id": 3, "target_slot": 1, "type": "CONDITIONING"}, {"id": 99, "origin_id": 49, "origin_slot": 1, "target_id": 3, "target_slot": 2, "type": "CONDITIONING"}, {"id": 160, "origin_id": 49, "origin_slot": 2, "target_id": 3, "target_slot": 3, "type": "LATENT"}, {"id": 74, "origin_id": 38, "origin_slot": 0, "target_id": 6, "target_slot": 0, "type": "CLIP"}, {"id": 248, "origin_id": 140, "origin_slot": 0, "target_id": 154, "target_slot": 0, "type": "MODEL"}, {"id": 75, "origin_id": 38, "origin_slot": 0, "target_id": 7, "target_slot": 0, "type": "CLIP"}, {"id": 351, "origin_id": -10, "origin_slot": 0, "target_id": 219, "target_slot": 0, "type": "MASK"}, {"id": 335, "origin_id": -10, "origin_slot": 0, "target_id": 208, "target_slot": 2, "type": "MASK"}, {"id": 345, "origin_id": -10, "origin_slot": 0, "target_id": 111, "target_slot": 0, "type": "MASK"}, {"id": 336, "origin_id": -10, "origin_slot": 1, "target_id": 210, "target_slot": 0, "type": "VIDEO"}, {"id": 129, "origin_id": 68, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 355, "origin_id": -10, "origin_slot": 2, "target_id": 49, "target_slot": 6, "type": "INT"}, {"id": 356, "origin_id": -10, "origin_slot": 3, "target_id": 49, "target_slot": 7, "type": "INT"}, {"id": 358, "origin_id": 210, "origin_slot": 0, "target_id": 229, "target_slot": 0, "type": "IMAGE"}, {"id": 359, "origin_id": 229, "origin_slot": 0, "target_id": 211, "target_slot": 0, "type": "IMAGE"}, {"id": 360, "origin_id": 229, "origin_slot": 0, "target_id": 213, "target_slot": 0, "type": "IMAGE"}, {"id": 361, "origin_id": -10, "origin_slot": 4, "target_id": 49, "target_slot": 5, "type": "IMAGE"}, {"id": 362, "origin_id": 210, "origin_slot": 1, "target_id": 68, "target_slot": 1, "type": "AUDIO"}, {"id": 363, "origin_id": -10, "origin_slot": 5, "target_id": 140, "target_slot": 0, "type": "COMBO"}, {"id": 364, "origin_id": -10, "origin_slot": 6, "target_id": 154, "target_slot": 1, "type": "COMBO"}, {"id": 365, "origin_id": -10, "origin_slot": 7, "target_id": 38, "target_slot": 0, "type": "COMBO"}, {"id": 366, "origin_id": -10, "origin_slot": 8, "target_id": 39, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Inpaint video"}]}, 
"config": {}, "extra": {"workflowRendererVersion": "LG", "ds": {"scale": 0.8183828377358485, "offset": [1215.8643989712405, 178.87024992690183]}}, "version": 0.4} diff --git a/blueprints/Video Stitch.json b/blueprints/Video Stitch.json new file mode 100644 index 000000000..11bcf6b7d --- /dev/null +++ b/blueprints/Video Stitch.json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 84, "last_link_id": 0, "nodes": [{"id": 84, "type": "8e8aa94a-647e-436d-8440-8ee4691864de", "pos": [-6100, 2620], "size": [290, 160], "flags": {}, "order": 0, "mode": 0, "inputs": [{"label": "Before Video", "localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"label": "After Video", "localized_name": "video_1", "name": "video_1", "type": "VIDEO", "link": null}, {"name": "direction", "type": "COMBO", "widget": {"name": "direction"}, "link": null}, {"name": "match_image_size", "type": "BOOLEAN", "widget": {"name": "match_image_size"}, "link": null}, {"name": "spacing_width", "type": "INT", "widget": {"name": "spacing_width"}, "link": null}, {"name": "spacing_color", "type": "COMBO", "widget": {"name": "spacing_color"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "properties": {"proxyWidgets": [["-1", "direction"], ["-1", "match_image_size"], ["-1", "spacing_width"], ["-1", "spacing_color"]], "cnr_id": "comfy-core", "ver": "0.13.0"}, "widgets_values": ["right", true, 0, "white"], "title": "Video Stitch"}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "8e8aa94a-647e-436d-8440-8ee4691864de", "version": 1, "state": {"lastGroupId": 1, "lastNodeId": 84, "lastLinkId": 262, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Stitch", "inputNode": {"id": -10, "bounding": [-6580, 2649, 143.55859375, 160]}, "outputNode": {"id": -20, "bounding": [-5720, 2659, 120, 60]}, "inputs": [{"id": "85555afe-c7a1-4f6e-b073-7c37f7bace7f", "name": "video", "type": "VIDEO", "linkIds": [253], "localized_name": "video", "label": "Before Video", "pos": [-6456.44140625, 2669]}, {"id": "022773ee-6b4f-4e3d-bead-68b3e75e2d20", "name": "video_1", "type": "VIDEO", "linkIds": [254], "localized_name": "video_1", "label": "After Video", "pos": [-6456.44140625, 2689]}, {"id": "7bcd7cbc-e918-472a-a0cf-2e0900545372", "name": "direction", "type": "COMBO", "linkIds": [259], "pos": [-6456.44140625, 2709]}, {"id": "9a00389d-c1c8-40d5-87fe-f41019b61fbc", "name": "match_image_size", "type": "BOOLEAN", "linkIds": [260], "pos": [-6456.44140625, 2729]}, {"id": "b95e0440-3ea8-4ae0-887e-12e75701042a", "name": "spacing_width", "type": "INT", "linkIds": [261], "pos": [-6456.44140625, 2749]}, {"id": "83ab9382-0a70-4169-b26a-66ab026b43c4", "name": "spacing_color", "type": "COMBO", "linkIds": [262], "pos": [-6456.44140625, 2769]}], "outputs": [{"id": "09707f43-7552-4a6e-bd23-d962d31801c2", "name": "VIDEO", "type": "VIDEO", "linkIds": [255], "localized_name": "VIDEO", "pos": [-5700, 2679]}], "widgets": [], "nodes": [{"id": 78, "type": "GetVideoComponents", "pos": [-6390, 2560], "size": [193.530859375, 66], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 254}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [249]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": null}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": null}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": 
"GetVideoComponents"}}, {"id": 77, "type": "GetVideoComponents", "pos": [-6390, 2420], "size": [193.530859375, 66], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 253}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [248]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [251]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [252]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 79, "type": "ImageStitch", "pos": [-6390, 2700], "size": [270, 150], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "image1", "name": "image1", "type": "IMAGE", "link": 248}, {"localized_name": "image2", "name": "image2", "shape": 7, "type": "IMAGE", "link": 249}, {"localized_name": "direction", "name": "direction", "type": "COMBO", "widget": {"name": "direction"}, "link": 259}, {"localized_name": "match_image_size", "name": "match_image_size", "type": "BOOLEAN", "widget": {"name": "match_image_size"}, "link": 260}, {"localized_name": "spacing_width", "name": "spacing_width", "type": "INT", "widget": {"name": "spacing_width"}, "link": 261}, {"localized_name": "spacing_color", "name": "spacing_color", "type": "COMBO", "widget": {"name": "spacing_color"}, "link": 262}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [250]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "ImageStitch"}, "widgets_values": ["right", true, 0, "white"]}, {"id": 80, "type": "CreateVideo", "pos": [-6040, 2610], "size": [270, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 250}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 251}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 252}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [255]}], "properties": {"cnr_id": "comfy-core", "ver": "0.13.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}], "groups": [], "links": [{"id": 248, "origin_id": 77, "origin_slot": 0, "target_id": 79, "target_slot": 0, "type": "IMAGE"}, {"id": 249, "origin_id": 78, "origin_slot": 0, "target_id": 79, "target_slot": 1, "type": "IMAGE"}, {"id": 250, "origin_id": 79, "origin_slot": 0, "target_id": 80, "target_slot": 0, "type": "IMAGE"}, {"id": 251, "origin_id": 77, "origin_slot": 1, "target_id": 80, "target_slot": 1, "type": "AUDIO"}, {"id": 252, "origin_id": 77, "origin_slot": 2, "target_id": 80, "target_slot": 2, "type": "FLOAT"}, {"id": 253, "origin_id": -10, "origin_slot": 0, "target_id": 77, "target_slot": 0, "type": "VIDEO"}, {"id": 254, "origin_id": -10, "origin_slot": 1, "target_id": 78, "target_slot": 0, "type": "VIDEO"}, {"id": 255, "origin_id": 80, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 259, "origin_id": -10, "origin_slot": 2, "target_id": 79, "target_slot": 2, "type": "COMBO"}, {"id": 260, "origin_id": -10, "origin_slot": 3, "target_id": 79, "target_slot": 3, "type": "BOOLEAN"}, {"id": 261, "origin_id": -10, "origin_slot": 4, "target_id": 79, "target_slot": 4, "type": "INT"}, {"id": 262, "origin_id": -10, "origin_slot": 5, "target_id": 79, "target_slot": 5, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video Tools/Stitch videos"}]}} diff 
--git a/blueprints/Video Upscale(GAN x4).json b/blueprints/Video Upscale(GAN x4).json new file mode 100644 index 000000000..e80b2e229 --- /dev/null +++ b/blueprints/Video Upscale(GAN x4).json @@ -0,0 +1 @@ +{"revision": 0, "last_node_id": 13, "last_link_id": 0, "nodes": [{"id": 13, "type": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "pos": [1120, 330], "size": [240, 58], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": null}, {"name": "model_name", "type": "COMBO", "widget": {"name": "model_name"}, "link": null}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": []}], "title": "Video Upscale(GAN x4)", "properties": {"proxyWidgets": [["-1", "model_name"]], "cnr_id": "comfy-core", "ver": "0.14.1"}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "links": [], "version": 0.4, "definitions": {"subgraphs": [{"id": "cf95b747-3e17-46cb-8097-cac60ff9b2e1", "version": 1, "state": {"lastGroupId": 0, "lastNodeId": 13, "lastLinkId": 19, "lastRerouteId": 0}, "revision": 0, "config": {}, "name": "Video Upscale(GAN x4)", "inputNode": {"id": -10, "bounding": [550, 460, 120, 80]}, "outputNode": {"id": -20, "bounding": [1490, 460, 120, 60]}, "inputs": [{"id": "666d633e-93e7-42dc-8d11-2b7b99b0f2a6", "name": "video", "type": "VIDEO", "linkIds": [10], "localized_name": "video", "pos": [650, 480]}, {"id": "2e23a087-caa8-4d65-99e6-662761aa905a", "name": "model_name", "type": "COMBO", "linkIds": [19], "pos": [650, 500]}], "outputs": [{"id": "0c1768ea-3ec2-412f-9af6-8e0fa36dae70", "name": "VIDEO", "type": "VIDEO", "linkIds": [15], "localized_name": "VIDEO", "pos": [1510, 480]}], "widgets": [], "nodes": [{"id": 2, "type": "ImageUpscaleWithModel", "pos": [1110, 450], "size": [320, 46], "flags": {}, "order": 1, "mode": 0, "inputs": [{"localized_name": "upscale_model", "name": "upscale_model", "type": "UPSCALE_MODEL", "link": 1}, {"localized_name": "image", "name": "image", "type": "IMAGE", "link": 14}], "outputs": [{"localized_name": "IMAGE", "name": "IMAGE", "type": "IMAGE", "links": [13]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "ImageUpscaleWithModel"}}, {"id": 11, "type": "CreateVideo", "pos": [1110, 550], "size": [320, 78], "flags": {}, "order": 3, "mode": 0, "inputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "link": 13}, {"localized_name": "audio", "name": "audio", "shape": 7, "type": "AUDIO", "link": 16}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "widget": {"name": "fps"}, "link": 12}], "outputs": [{"localized_name": "VIDEO", "name": "VIDEO", "type": "VIDEO", "links": [15]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "CreateVideo"}, "widgets_values": [30]}, {"id": 10, "type": "GetVideoComponents", "pos": [1110, 330], "size": [320, 70], "flags": {}, "order": 2, "mode": 0, "inputs": [{"localized_name": "video", "name": "video", "type": "VIDEO", "link": 10}], "outputs": [{"localized_name": "images", "name": "images", "type": "IMAGE", "links": [14]}, {"localized_name": "audio", "name": "audio", "type": "AUDIO", "links": [16]}, {"localized_name": "fps", "name": "fps", "type": "FLOAT", "links": [12]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "GetVideoComponents"}}, {"id": 1, "type": "UpscaleModelLoader", "pos": [750, 450], "size": [280, 60], "flags": {}, "order": 0, "mode": 0, "inputs": [{"localized_name": "model_name", "name": "model_name", "type": "COMBO", 
"widget": {"name": "model_name"}, "link": 19}], "outputs": [{"localized_name": "UPSCALE_MODEL", "name": "UPSCALE_MODEL", "type": "UPSCALE_MODEL", "links": [1]}], "properties": {"cnr_id": "comfy-core", "ver": "0.10.0", "Node name for S&R": "UpscaleModelLoader", "models": [{"name": "RealESRGAN_x4plus.safetensors", "url": "https://huggingface.co/Comfy-Org/Real-ESRGAN_repackaged/resolve/main/RealESRGAN_x4plus.safetensors", "directory": "upscale_models"}]}, "widgets_values": ["RealESRGAN_x4plus.safetensors"]}], "groups": [], "links": [{"id": 1, "origin_id": 1, "origin_slot": 0, "target_id": 2, "target_slot": 0, "type": "UPSCALE_MODEL"}, {"id": 14, "origin_id": 10, "origin_slot": 0, "target_id": 2, "target_slot": 1, "type": "IMAGE"}, {"id": 13, "origin_id": 2, "origin_slot": 0, "target_id": 11, "target_slot": 0, "type": "IMAGE"}, {"id": 16, "origin_id": 10, "origin_slot": 1, "target_id": 11, "target_slot": 1, "type": "AUDIO"}, {"id": 12, "origin_id": 10, "origin_slot": 2, "target_id": 11, "target_slot": 2, "type": "FLOAT"}, {"id": 10, "origin_id": -10, "origin_slot": 0, "target_id": 10, "target_slot": 0, "type": "VIDEO"}, {"id": 15, "origin_id": 11, "origin_slot": 0, "target_id": -20, "target_slot": 0, "type": "VIDEO"}, {"id": 19, "origin_id": -10, "origin_slot": 1, "target_id": 1, "target_slot": 0, "type": "COMBO"}], "extra": {"workflowRendererVersion": "LG"}, "category": "Video generation and editing/Enhance video"}]}, "extra": {}} diff --git a/blueprints/put_blueprints_here b/blueprints/put_blueprints_here new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/audio_encoders/audio_encoders.py b/comfy/audio_encoders/audio_encoders.py index 46ef21c95..0de7584b0 100644 --- a/comfy/audio_encoders/audio_encoders.py +++ b/comfy/audio_encoders/audio_encoders.py @@ -25,11 +25,12 @@ class AudioEncoderModel(): elif model_type == "whisper3": self.model = WhisperLargeV3(**model_config) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.model_sample_rate = 16000 + comfy.model_management.archive_model_dtypes(self.model) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py deleted file mode 100644 index 206551d3c..000000000 --- a/comfy/checkpoint_pickle.py +++ /dev/null @@ -1,13 +0,0 @@ -import pickle - -load = pickle.load - -class Empty: - pass - -class Unpickler(pickle.Unpickler): - def find_class(self, module, name): - #TODO: safe unpickle - if module.startswith("pytorch_lightning"): - return Empty - return super().find_class(module, name) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0b51e3c07..551ac4149 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -84,6 +84,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") +parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.") + 
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") @@ -147,6 +149,8 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") +parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") +parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -232,6 +236,7 @@ database_default_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") +parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") if comfy.options.args_parsing: args = parser.parse_args() @@ -257,3 +262,8 @@ elif args.fast == []: # '--fast' is provided with a list of performance features, use that list else: args.fast = set(args.fast) + +def enables_dynamic_vram(): + if args.enable_dynamic_vram: + return True + return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/clip_model.py b/comfy/clip_model.py index 7c0cadab5..d7d3f994c 100644 --- a/comfy/clip_model.py +++ b/comfy/clip_model.py @@ -1,6 +1,59 @@ import torch from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.ops +import math + +def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + if not (image.shape[2] == size and image.shape[3] == size): + if crop: + scale = (size / min(image.shape[2], image.shape[3])) + scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) + else: + scale_size = (size, size) + + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + h = (image.shape[2] - size)//2 + w = (image.shape[3] - size)//2 + image = image[:,:,h:h+size,w:w+size] + image = torch.clip((255. 
* image), 0, 255).round() / 255.0 + return (image - mean.view([3,1,1])) / std.view([3,1,1]) + +def siglip2_flex_calc_resolution(oh, ow, patch_size, max_num_patches, eps=1e-5): + def scale_dim(size, scale): + scaled = math.ceil(size * scale / patch_size) * patch_size + return max(patch_size, int(scaled)) + + # Binary search for optimal scale + lo, hi = eps / 10, 100.0 + while hi - lo >= eps: + mid = (lo + hi) / 2 + h, w = scale_dim(oh, mid), scale_dim(ow, mid) + if (h // patch_size) * (w // patch_size) <= max_num_patches: + lo = mid + else: + hi = mid + + return scale_dim(oh, lo), scale_dim(ow, lo) + +def siglip2_preprocess(image, size, patch_size, num_patches, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True): + if size > 0: + return clip_preprocess(image, size=size, mean=mean, std=std, crop=crop) + + image = image[:, :, :, :3] if image.shape[3] > 3 else image + mean = torch.tensor(mean, device=image.device, dtype=image.dtype) + std = torch.tensor(std, device=image.device, dtype=image.dtype) + image = image.movedim(-1, 1) + + b, c, h, w = image.shape + h, w = siglip2_flex_calc_resolution(h, w, patch_size, num_patches) + + image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear", antialias=True) + image = torch.clip((255. * image), 0, 255).round() / 255.0 + return (image - mean.view([3, 1, 1])) / std.view([3, 1, 1]) class CLIPAttention(torch.nn.Module): def __init__(self, embed_dim, heads, dtype, device, operations): @@ -156,6 +209,27 @@ class CLIPTextModel(torch.nn.Module): out = self.text_projection(x[2]) return (x[0], x[1], out, x[2]) +def siglip2_pos_embed(embed_weight, embeds, orig_shape): + embed_weight_len = round(embed_weight.shape[0] ** 0.5) + embed_weight = comfy.ops.cast_to_input(embed_weight, embeds).movedim(1, 0).reshape(1, -1, embed_weight_len, embed_weight_len) + embed_weight = torch.nn.functional.interpolate(embed_weight, size=orig_shape, mode="bilinear", align_corners=False, antialias=True) + embed_weight = embed_weight.reshape(-1, embed_weight.shape[-2] * embed_weight.shape[-1]).movedim(0, 1) + return embeds + embed_weight + +class Siglip2Embeddings(torch.nn.Module): + def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", num_patches=None, dtype=None, device=None, operations=None): + super().__init__() + self.patch_embedding = operations.Linear(num_channels * patch_size * patch_size, embed_dim, dtype=dtype, device=device) + self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device) + self.patch_size = patch_size + + def forward(self, pixel_values): + b, c, h, w = pixel_values.shape + img = pixel_values.movedim(1, -1).reshape(b, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size, c) + img = img.permute(0, 1, 3, 2, 4, 5) + img = img.reshape(b, img.shape[1] * img.shape[2], -1) + img = self.patch_embedding(img) + return siglip2_pos_embed(self.position_embedding.weight, img, (h // self.patch_size, w // self.patch_size)) class CLIPVisionEmbeddings(torch.nn.Module): def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None): @@ -199,8 +273,11 @@ class CLIPVision(torch.nn.Module): intermediate_activation = config_dict["hidden_act"] model_type = config_dict["model_type"] - self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) - if 
model_type == "siglip_vision_model": + if model_type in ["siglip2_vision_model"]: + self.embeddings = Siglip2Embeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, num_patches=config_dict.get("num_patches", None), dtype=dtype, device=device, operations=operations) + else: + self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations) + if model_type in ["siglip_vision_model", "siglip2_vision_model"]: self.pre_layrnorm = lambda a: a self.output_layernorm = True else: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 447b1ce4a..1691fca81 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace import os -import torch import json import logging @@ -17,28 +16,12 @@ class Output: def __setitem__(self, key, item): setattr(self, key, item) -def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True): - image = image[:, :, :, :3] if image.shape[3] > 3 else image - mean = torch.tensor(mean, device=image.device, dtype=image.dtype) - std = torch.tensor(std, device=image.device, dtype=image.dtype) - image = image.movedim(-1, 1) - if not (image.shape[2] == size and image.shape[3] == size): - if crop: - scale = (size / min(image.shape[2], image.shape[3])) - scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3])) - else: - scale_size = (size, size) - - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) - h = (image.shape[2] - size)//2 - w = (image.shape[3] - size)//2 - image = image[:,:,h:h+size,w:w+size] - image = torch.clip((255. 
* image), 0, 255).round() / 255.0 - return (image - mean.view([3,1,1])) / std.view([3,1,1]) +clip_preprocess = comfy.clip_model.clip_preprocess # Backwards-compatibility alias for code that still imports clip_preprocess from comfy.clip_vision; TODO: remove eventually IMAGE_ENCODERS = { "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, + "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, } @@ -50,9 +33,10 @@ class ClipVisionModel(): self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711]) - model_type = config.get("model_type", "clip_vision_model") - model_class = IMAGE_ENCODERS.get(model_type) - if model_type == "siglip_vision_model": + self.model_type = config.get("model_type", "clip_vision_model") + self.config = config.copy() + model_class = IMAGE_ENCODERS.get(self.model_type) + if self.model_type == "siglip_vision_model": self.return_all_hidden_states = True else: self.return_all_hidden_states = False @@ -63,22 +47,26 @@ class ClipVisionModel(): self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() - self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=False) + return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() def encode_image(self, image, crop=True): comfy.model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + if self.model_type == "siglip2_vision_model": + pixel_values = comfy.clip_model.siglip2_preprocess(image.to(self.load_device), size=self.image_size, patch_size=self.config.get("patch_size", 16), num_patches=self.config.get("num_patches", 256), mean=self.image_mean, std=self.image_std, crop=crop).float() + else: + pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float() + out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2) outputs = Output() outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device()) outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device()) + outputs["image_sizes"] = [pixel_values.shape[1:]] * pixel_values.shape[0] if self.return_all_hidden_states: all_hs = out[1].to(comfy.model_management.intermediate_device()) outputs["penultimate_hidden_states"] = all_hs[:, -2] @@ -125,10 +113,14 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0] if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152: - if embed_shape == 729: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") - elif embed_shape == 1024: - json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)),
"clip_vision_siglip_512.json") + patch_embedding_shape = sd["vision_model.embeddings.patch_embedding.weight"].shape + if len(patch_embedding_shape) == 2: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip2_base_naflex.json") + else: + if embed_shape == 729: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json") + elif embed_shape == 1024: + json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json") elif embed_shape == 577: if "multi_modal_projector.linear_1.bias" in sd: json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json") diff --git a/comfy/clip_vision_siglip2_base_naflex.json b/comfy/clip_vision_siglip2_base_naflex.json new file mode 100644 index 000000000..6f6b99bd6 --- /dev/null +++ b/comfy/clip_vision_siglip2_base_naflex.json @@ -0,0 +1,14 @@ +{ + "num_channels": 3, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": -1, + "intermediate_size": 4304, + "model_type": "siglip2_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 16, + "num_patches": 256, + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5] +} diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py index 071b98332..57126fa4a 100644 --- a/comfy/comfy_types/node_typing.py +++ b/comfy/comfy_types/node_typing.py @@ -176,6 +176,8 @@ class InputTypeOptions(TypedDict): """COMBO type only. Specifies the configuration for a multi-select widget. Available after ComfyUI frontend v1.13.4 https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" + gradient_stops: NotRequired[list[dict]] + """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}.""" class HiddenInputTypeDict(TypedDict): @@ -236,6 +238,8 @@ class ComfyNodeABC(ABC): """Flags a node as experimental, informing users that it may change or not work as expected.""" DEPRECATED: bool """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" + DEV_ONLY: bool + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" API_NODE: Optional[bool] """Flags a node as an API node. 
See: https://docs.comfy.org/tutorials/api-nodes/overview.""" diff --git a/comfy/conds.py b/comfy/conds.py index 5af3e93ea..55d8cdd78 100644 --- a/comfy/conds.py +++ b/comfy/conds.py @@ -4,6 +4,25 @@ import comfy.utils import logging +def is_equal(x, y): + if torch.is_tensor(x) and torch.is_tensor(y): + return torch.equal(x, y) + elif isinstance(x, dict) and isinstance(y, dict): + if x.keys() != y.keys(): + return False + return all(is_equal(x[k], y[k]) for k in x) + elif isinstance(x, (list, tuple)) and isinstance(y, (list, tuple)): + if type(x) is not type(y) or len(x) != len(y): + return False + return all(is_equal(a, b) for a, b in zip(x, y)) + else: + try: + return x == y + except Exception: + logging.warning("comparison issue with COND") + return False + + class CONDRegular: def __init__(self, cond): self.cond = cond @@ -84,7 +103,7 @@ class CONDConstant(CONDRegular): return self._copy_with(self.cond) def can_concat(self, other): - if self.cond != other.cond: + if not is_equal(self.cond, other.cond): return False return True diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 2979b3ca1..cb44ee6e8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -93,6 +93,50 @@ class IndexListCallbacks: return {} +def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]): + if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)): + return None + cond_tensor = cond_value.cond + if temporal_dim >= cond_tensor.ndim: + return None + + cond_size = cond_tensor.size(temporal_dim) + + if temporal_scale == 1: + expected_size = x_in.size(window.dim) - temporal_offset + if cond_size != expected_size: + return None + + if temporal_offset == 0 and temporal_scale == 1: + sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list) + return cond_value._copy_with(sliced) + + # skip leading latent positions that have no corresponding conditioning (e.g. 
reference frames) + if temporal_offset > 0: + indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + indices = [i for i in indices if 0 <= i] + else: + indices = list(window.index_list) + + if not indices: + return None + + if temporal_scale > 1: + scaled = [] + for i in indices: + for k in range(temporal_scale): + si = i * temporal_scale + k + if si < cond_size: + scaled.append(si) + indices = scaled + if not indices: + return None + + idx = tuple([slice(None)] * temporal_dim + [indices]) + sliced = cond_tensor[idx].to(device) + return cond_value._copy_with(sliced) + + @dataclass class ContextSchedule: name: str @@ -143,7 +187,7 @@ class IndexListContextHandler(ContextHandlerABC): # if multiple conds, split based on primary region if self.split_conds_to_windows and len(cond_in) > 1: region = window.get_region_index(len(cond_in)) - logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + logging.info(f"Splitting conds to windows; using region {region} for window {window.index_list[0]}-{window.index_list[-1]} with center ratio {window.center_ratio:.3f}") cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: @@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item[cond_key] = result handled = True break + if not handled and self._model is not None: + result = self._model.resize_cond_for_context_window( + cond_key, cond_value, window, x_in, device, + retain_index_list=self.cond_retain_index_list) + if result is not None: + new_cond_item[cond_key] = result + handled = True if handled: continue if isinstance(cond_value, torch.Tensor): - if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) # Handle audio_embed (temporal dim is 1) @@ -188,6 +239,12 @@ class IndexListContextHandler(ContextHandlerABC): audio_cond = cond_value.cond if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) + # Handle vace_context (temporal dim is 3) + elif cond_key == "vace_context" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + vace_cond = cond_value.cond + if vace_cond.ndim >= 4 and vace_cond.size(3) == x_in.size(self.dim): + sliced_vace = window.get_tensor(vace_cond, device, dim=3, retain_index_list=self.cond_retain_index_list) + new_cond_item[cond_key] = cond_value._copy_with(sliced_vace) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ @@ -208,7 +265,7 @@ class IndexListContextHandler(ContextHandlerABC): mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: - raise Exception("No sample_sigmas matched current timestep; something went wrong.") + return # substep from multi-step sampler: keep self._step from the last full step self._step = int(matches[0].item()) def 
get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: @@ -218,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC): return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self._model = model self.set_step(timestep, model_options) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 0b5e30f52..ba670b16d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -203,7 +203,7 @@ class ControlNet(ControlBase): self.control_model = control_model self.load_device = load_device if control_model is not None: - self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) + self.control_model_wrapped = comfy.model_patcher.CoreModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device()) self.compression_ratio = compression_ratio self.global_average_pooling = global_average_pooling @@ -297,6 +297,30 @@ class ControlNet(ControlBase): self.model_sampling_current = None super().cleanup() + +class QwenFunControlNet(ControlNet): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): + # Fun checkpoints are more sensitive to high strengths in the generic + # ControlNet merge path. Use a soft response curve so strength=1.0 stays + # unchanged while >1 grows more gently. + original_strength = self.strength + self.strength = math.sqrt(max(self.strength, 0.0)) + try: + return super().get_control(x_noisy, t, cond, batched_number, transformer_options) + finally: + self.strength = original_strength + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.set_extra_arg("base_model", model.diffusion_model) + + def copy(self): + c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + class ControlLoraOps: class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, @@ -560,6 +584,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + sd = model_config.process_unet_state_dict(sd) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) @@ -605,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, 
manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control + +def load_controlnet_qwen_fun(sd, model_options={}): + load_device = comfy.model_management.get_torch_device() + weight_dtype = comfy.utils.weight_dtype(sd) + unet_dtype = model_options.get("dtype", weight_dtype) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) + + in_features = sd["control_img_in.weight"].shape[1] + inner_dim = sd["control_img_in.weight"].shape[0] + + block_weight = sd["control_blocks.0.attn.to_q.weight"] + attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0] + num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim)) + + model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel( + control_in_features=in_features, + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + operations=operations, + device=comfy.model_management.unet_offload_device(), + dtype=unet_dtype, + ) + model = controlnet_load_state_dict(model, sd) + + latent_format = comfy.latent_formats.Wan21() + control = QwenFunControlNet( + model, + compression_ratio=1, + latent_format=latent_format, + # Fun checkpoints already expect their own 33-channel context handling. + # Enabling generic concat_mask injects an extra mask channel at apply-time + # and breaks the intended fallback packing path. + concat_mask=False, + load_device=load_device, + manual_cast_dtype=manual_cast_dtype, + extra_conds=[], + ) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -682,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data: + return load_controlnet_qwen_fun(controlnet_data, model_options=model_options) elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/float.py b/comfy/float.py index 521316fd2..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -65,3 +65,183 @@ def stochastic_rounding(value, dtype, seed=0): return output return value.to(dtype=dtype) + + +# TODO: improve this? 
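+# Reference for the E2M1 packing below (standard FP4 E2M1: 1 sign bit, 2
+# exponent bits, 1 mantissa bit; two codes packed per uint8, first value in the
+# high nibble):
+#   representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6  (hence F4_E2M1_MAX = 6.0)
+#   code = (sign << 3) | (exp << 1) | mantissa
+# Worked examples: 3.0 -> sign=0, exp=floor(log2(3) + 1.1925)=2, mantissa=(3/2 - 1)*2=1,
+# i.e. 0b0101; -0.5 -> sign=1, exp=0, mantissa=round(0.5*2)=1, i.e. 0b1001. The
+# 1.1925 offset is ~(2 - log2(1.75)), which puts the exponent cutoff at the
+# midpoint 1.75 * 2^(e-1) between the largest exp-e and smallest exp-(e+1) magnitudes.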
+def stochastic_float_to_fp4_e2m1(x, generator): + orig_shape = x.shape + sign = torch.signbit(x).to(torch.uint8) + + exp = torch.floor(torch.log2(x.abs()) + 1.0).clamp(0, 3) + x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25 + + x = x.abs() + exp = torch.floor(torch.log2(x) + 1.1925).clamp(0, 3) + + mantissa = torch.where( + exp > 0, + (x / (2.0 ** (exp - 1)) - 1.0) * 2.0, + (x * 2.0), + out=x + ).round().to(torch.uint8) + del x + + exp = exp.to(torch.uint8) + + fp4 = (sign << 3) | (exp << 1) | mantissa + del sign, exp, mantissa + + fp4_flat = fp4.view(-1) + packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2] + return packed.reshape(list(orig_shape)[:-1] + [-1]) + + +def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor: + """ + Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern. + See: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + Returns: + Rearranged tensor of shape (32*ceil_div(H,128), 16*ceil_div(W,4)) + """ + + def ceil_div(a, b): + return (a + b - 1) // b + + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + # Calculate the padded shape + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix + + # Rearrange the blocks + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + if flatten: + return rearranged.flatten() + + return rearranged.reshape(padded_rows, padded_cols) + + +def stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator): + F4_E2M1_MAX = 6.0 + F8_E4M3_MAX = 448.0 + + orig_shape = x.shape + + block_size = 16 + + x = x.reshape(orig_shape[0], -1, block_size) + scaled_block_scales_fp8 = torch.clamp(((torch.amax(torch.abs(x), dim=-1)) / F4_E2M1_MAX) / per_tensor_scale.to(x.dtype), max=F8_E4M3_MAX).to(torch.float8_e4m3fn) + x = x / (per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)).unsqueeze(-1) + + x = x.view(orig_shape).nan_to_num() + data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator) + return data_lp, scaled_block_scales_fp8 + + +def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + # Handle padding + if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + x, blocked_scaled = stochastic_round_quantize_nvfp4_block(x, per_tensor_scale, generator) + return x, to_blocked(blocked_scaled, flatten=False) + + +def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=0, block_size=4096 * 4096): + def roundup(x: int, multiple: int) -> int: + """Round up x to the nearest multiple.""" + return ((x + multiple - 1) // multiple) * multiple + + orig_shape = x.shape + + # Handle padding + 
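+    # NVFP4 stores one fp8 (E4M3) scale per 16-element block on top of a single
+    # per-tensor scale, so when pad_16x is set both dims are rounded up to
+    # multiples of 16: columns then split into whole scale blocks (and an even
+    # count of codes for the 2-per-byte packing), and rows match the 16-aligned
+    # layout the blocked-scale consumers expect.
+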
if pad_16x: + rows, cols = x.shape + padded_rows = roundup(rows, 16) + padded_cols = roundup(cols, 16) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + # Note: We update orig_shape because the output tensor logic below assumes x.shape matches + # what we want to produce. If we pad here, we want the padded output. + orig_shape = x.shape + + orig_shape = list(orig_shape) + + output_fp4 = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 2], dtype=torch.uint8, device=x.device) + output_block = torch.empty(orig_shape[:-1] + [orig_shape[-1] // 16], dtype=torch.float8_e4m3fn, device=x.device) + + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + num_slices = max(1, (x.numel() / block_size)) + slice_size = max(1, (round(x.shape[0] / num_slices))) + + for i in range(0, x.shape[0], slice_size): + fp4, block = stochastic_round_quantize_nvfp4_block(x[i: i + slice_size], per_tensor_scale, generator=generator) + output_fp4[i:i + slice_size].copy_(fp4) + output_block[i:i + slice_size].copy_(block) + + return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/hooks.py b/comfy/hooks.py index 9d0731072..1a76c7ba4 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -527,7 +527,8 @@ class HookKeyframeGroup: if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0: break # if eval_c is outside the percent range, stop looking further - else: break + else: + break # update steps current context is used self._current_used_steps += 1 # update current timestep this was performed on diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 1ba9edad7..6978eb717 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -5,7 +5,7 @@ from scipy import integrate import torch from torch import nn import torchsde -from tqdm.auto import trange, tqdm +from tqdm.auto import tqdm from . import utils from . import deis @@ -13,6 +13,9 @@ from . 
import sa_solver import comfy.model_patcher import comfy.model_sampling +import comfy.memory_management +from comfy.utils import model_trange as trange + def append_zero(x): return torch.cat([x, x.new_zeros([1])]) @@ -74,6 +77,9 @@ def get_ancestral_step(sigma_from, sigma_to, eta=1.): def default_noise_sampler(x, seed=None): if seed is not None: + if x.device == torch.device("cpu"): + seed += 1 + generator = torch.Generator(device=x.device) generator.manual_seed(seed) else: diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -8,6 +8,7 @@ class LatentFormat: latent_rgb_factors_bias = None latent_rgb_factors_reshape = None taesd_decoder_name = None + spacial_downscale_ratio = 8 def process_in(self, latent): return latent * self.scale_factor @@ -80,6 +81,7 @@ class SD_X4(LatentFormat): class SC_Prior(LatentFormat): latent_channels = 16 + spacial_downscale_ratio = 42 def __init__(self): self.scale_factor = 1.0 self.latent_rgb_factors = [ @@ -102,6 +104,7 @@ class SC_Prior(LatentFormat): ] class SC_B(LatentFormat): + spacial_downscale_ratio = 4 def __init__(self): self.scale_factor = 1.0 / 0.43 self.latent_rgb_factors = [ @@ -181,6 +184,7 @@ class Flux(SD3): class Flux2(LatentFormat): latent_channels = 128 + spacial_downscale_ratio = 16 def __init__(self): self.latent_rgb_factors =[ @@ -272,6 +276,7 @@ class Mochi(LatentFormat): class LTXV(LatentFormat): latent_channels = 128 latent_dimensions = 3 + spacial_downscale_ratio = 32 def __init__(self): self.latent_rgb_factors = [ @@ -407,6 +412,11 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + def __init__(self): + self.latent_rgb_factors = None + self.latent_rgb_factors_bias = None + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 @@ -510,6 +520,7 @@ class Wan21(LatentFormat): class Wan22(Wan21): latent_channels = 48 latent_dimensions = 3 + spacial_downscale_ratio = 16 latent_rgb_factors = [ [ 0.0119, 0.0103, 0.0046], @@ -587,6 +598,7 @@ class Wan22(Wan21): class HunyuanImage21(LatentFormat): latent_channels = 64 latent_dimensions = 2 + spacial_downscale_ratio = 32 scale_factor = 0.75289 latent_rgb_factors = [ @@ -720,6 +732,7 @@ class HunyuanVideo15(LatentFormat): latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644] latent_channels = 32 latent_dimensions = 3 + spacial_downscale_ratio = 16 scale_factor = 1.03682 taesd_decoder_name = "lighttaehy1_5" @@ -742,8 +755,13 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class ACEAudio15(LatentFormat): + latent_channels = 64 + latent_dimensions = 1 + class ChromaRadiance(LatentFormat): latent_channels = 3 + spacial_downscale_ratio = 1 def __init__(self): self.latent_rgb_factors = [ @@ -758,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. 
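+    (The 3-channel layout, spacial_downscale_ratio of 1, and identity
+    process_out are inherited from ChromaRadiance.)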
+ """ + pass diff --git a/comfy/ldm/ace/ace_step15.py b/comfy/ldm/ace/ace_step15.py new file mode 100644 index 000000000..1d7dc59a8 --- /dev/null +++ b/comfy/ldm/ace/ace_step15.py @@ -0,0 +1,1155 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import itertools +from comfy.ldm.modules.attention import optimized_attention +import comfy.model_management +from comfy.ldm.flux.layers import timestep_embedding + +def get_silence_latent(length, device): + head = torch.tensor([[[ 0.5707, 0.0982, 0.6909, -0.5658, 0.6266, 0.6996, -0.1365, -0.1291, + -0.0776, -0.1171, -0.2743, -0.8422, -0.1168, 1.5539, -4.6936, 0.7436, + -1.1846, -0.2637, 0.6933, -6.7266, 0.0966, -0.1187, -0.3501, -1.1736, + 0.0587, -2.0517, -1.3651, 0.7508, -0.2490, -1.3548, -0.1290, -0.7261, + 1.1132, -0.3249, 0.2337, 0.3004, 0.6605, -0.0298, -0.1989, -0.4041, + 0.2843, -1.0963, -0.5519, 0.2639, -1.0436, -0.1183, 0.0640, 0.4460, + -1.1001, -0.6172, -1.3241, 1.1379, 0.5623, -0.1507, -0.1963, -0.4742, + -2.4697, 0.5302, 0.5381, 0.4636, -0.1782, -0.0687, 1.0333, 0.4202], + [ 0.3040, -0.1367, 0.6200, 0.0665, -0.0642, 0.4655, -0.1187, -0.0440, + 0.2941, -0.2753, 0.0173, -0.2421, -0.0147, 1.5603, -2.7025, 0.7907, + -0.9736, -0.0682, 0.1294, -5.0707, -0.2167, 0.3302, -0.1513, -0.8100, + -0.3894, -0.2884, -0.3149, 0.8660, -0.3817, -1.7061, 0.5824, -0.4840, + 0.6938, 0.1859, 0.1753, 0.3081, 0.0195, 0.1403, -0.0754, -0.2091, + 0.1251, -0.1578, -0.4968, -0.1052, -0.4554, -0.0320, 0.1284, 0.4974, + -1.1889, -0.0344, -0.8313, 0.2953, 0.5445, -0.6249, -0.1595, -0.0682, + -3.1412, 0.0484, 0.4153, 0.8260, -0.1526, -0.0625, 0.5366, 0.8473], + [ 5.3524e-02, -1.7534e-01, 5.4443e-01, -4.3501e-01, -2.1317e-03, + 3.7200e-01, -4.0143e-03, -1.5516e-01, -1.2968e-01, -1.5375e-01, + -7.7107e-02, -2.0593e-01, -3.2780e-01, 1.5142e+00, -2.6101e+00, + 5.8698e-01, -1.2716e+00, -2.4773e-01, -2.7933e-02, -5.0799e+00, + 1.1601e-01, 4.0987e-01, -2.2030e-02, -6.6495e-01, -2.0995e-01, + -6.3474e-01, -1.5893e-01, 8.2745e-01, -2.2992e-01, -1.6816e+00, + 5.4440e-01, -4.9579e-01, 5.5128e-01, 3.0477e-01, 8.3052e-02, + -6.1782e-02, 5.9036e-03, 2.9553e-01, -8.0645e-02, -1.0060e-01, + 1.9144e-01, -3.8124e-01, -7.2949e-01, 2.4520e-02, -5.0814e-01, + 2.3977e-01, 9.2943e-02, 3.9256e-01, -1.1993e+00, -3.2752e-01, + -7.2707e-01, 2.9476e-01, 4.3542e-01, -8.8597e-01, -4.1686e-01, + -8.5390e-02, -2.9018e+00, 6.4988e-02, 5.3945e-01, 9.1988e-01, + 5.8762e-02, -7.0098e-02, 6.4772e-01, 8.9118e-01], + [-3.2225e-02, -1.3195e-01, 5.6411e-01, -5.4766e-01, -5.2170e-03, + 3.1425e-01, -5.4367e-02, -1.9419e-01, -1.3059e-01, -1.3660e-01, + -9.0984e-02, -1.9540e-01, -2.5590e-01, 1.5440e+00, -2.6349e+00, + 6.8273e-01, -1.2532e+00, -1.9810e-01, -2.2793e-02, -5.0506e+00, + 1.8818e-01, 5.0109e-01, 7.3546e-03, -6.8771e-01, -3.0676e-01, + -7.3257e-01, -1.6687e-01, 9.2232e-01, -1.8987e-01, -1.7267e+00, + 5.3355e-01, -5.3179e-01, 4.4953e-01, 2.8820e-01, 1.3012e-01, + -2.0943e-01, -1.1348e-01, 3.3929e-01, -1.5069e-01, -1.2919e-01, + 1.8929e-01, -3.6166e-01, -8.0756e-01, 6.6387e-02, -5.8867e-01, + 1.6978e-01, 1.0134e-01, 3.3877e-01, -1.2133e+00, -3.2492e-01, + -8.1237e-01, 3.8101e-01, 4.3765e-01, -8.0596e-01, -4.4531e-01, + -4.7513e-02, -2.9266e+00, 1.1741e-03, 4.5123e-01, 9.3075e-01, + 5.3688e-02, -1.9621e-01, 6.4530e-01, 9.3870e-01]]], device=device).movedim(-1, 1) + + silence_latent = torch.tensor([[[-1.3672e-01, -1.5820e-01, 5.8594e-01, -5.7422e-01, 3.0273e-02, + 2.7930e-01, -2.5940e-03, -2.0703e-01, -1.6113e-01, -1.4746e-01, + -2.7710e-02, -1.8066e-01, -2.9688e-01, 
1.6016e+00, -2.6719e+00, + 7.7734e-01, -1.3516e+00, -1.9434e-01, -7.1289e-02, -5.0938e+00, + 2.4316e-01, 4.7266e-01, 4.6387e-02, -6.6406e-01, -2.1973e-01, + -6.7578e-01, -1.5723e-01, 9.5312e-01, -2.0020e-01, -1.7109e+00, + 5.8984e-01, -5.7422e-01, 5.1562e-01, 2.8320e-01, 1.4551e-01, + -1.8750e-01, -5.9814e-02, 3.6719e-01, -1.0059e-01, -1.5723e-01, + 2.0605e-01, -4.3359e-01, -8.2812e-01, 4.5654e-02, -6.6016e-01, + 1.4844e-01, 9.4727e-02, 3.8477e-01, -1.2578e+00, -3.3203e-01, + -8.5547e-01, 4.3359e-01, 4.2383e-01, -8.9453e-01, -5.0391e-01, + -5.6152e-02, -2.9219e+00, -2.4658e-02, 5.0391e-01, 9.8438e-01, + 7.2754e-02, -2.1582e-01, 6.3672e-01, 1.0000e+00]]], device=device).movedim(-1, 1).repeat(1, 1, length) + silence_latent[:, :, :head.shape[-1]] = head + return silence_latent + + +def get_layer_class(operations, layer_name): + if operations is not None and hasattr(operations, layer_name): + return getattr(operations, layer_name) + return getattr(nn, layer_name) + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=32768, base=1000000.0, dtype=None, device=None, operations=None): + super().__init__() + self.dim = dim + self.base = base + self.max_position_embeddings = max_position_embeddings + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._set_cos_sin_cache(max_position_embeddings, device=device, dtype=torch.get_default_dtype() if dtype is None else dtype) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len, x.device, x.dtype) + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype, device=x.device), + self.sin_cached[:seq_len].to(dtype=x.dtype, device=x.device), + ) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin): + cos = cos.unsqueeze(0).unsqueeze(0) + sin = sin.unsqueeze(0).unsqueeze(0) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.gate_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.up_proj = Linear(hidden_size, intermediate_size, bias=False, dtype=dtype, device=device) + self.down_proj = Linear(intermediate_size, hidden_size, bias=False, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + +class TimestepEmbedding(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, scale: float = 1000, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.linear_1 = Linear(in_channels, time_embed_dim, bias=True, dtype=dtype, device=device) + self.act1 = 
nn.SiLU() + self.linear_2 = Linear(time_embed_dim, time_embed_dim, bias=True, dtype=dtype, device=device) + self.in_channels = in_channels + self.act2 = nn.SiLU() + self.time_proj = Linear(time_embed_dim, time_embed_dim * 6, dtype=dtype, device=device) + self.scale = scale + + def forward(self, t, dtype=None): + t_freq = timestep_embedding(t, self.in_channels, time_factor=self.scale) + temb = self.linear_1(t_freq.to(dtype=dtype)) + temb = self.act1(temb) + temb = self.linear_2(temb) + timestep_proj = self.time_proj(self.act2(temb)).view(-1, 6, temb.shape[-1]) + return temb, timestep_proj + +class AceStepAttention(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + rms_norm_eps=1e-6, + is_cross_attention=False, + sliding_window=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.is_cross_attention = is_cross_attention + self.sliding_window = sliding_window + + Linear = get_layer_class(operations, "Linear") + + self.q_proj = Linear(hidden_size, num_heads * head_dim, bias=False, dtype=dtype, device=device) + self.k_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.v_proj = Linear(hidden_size, num_kv_heads * head_dim, bias=False, dtype=dtype, device=device) + self.o_proj = Linear(num_heads * head_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.q_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(head_dim, eps=rms_norm_eps, dtype=dtype, device=device) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + position_embeddings=None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + query_states = query_states.transpose(1, 2) + + if self.is_cross_attention and encoder_hidden_states is not None: + bsz_enc, kv_len, _ = encoder_hidden_states.size() + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + + key_states = key_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz_enc, kv_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + else: + kv_len = q_len + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + key_states = self.k_norm(key_states) + value_states = value_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) + + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + n_rep = self.num_heads // self.num_kv_heads + if n_rep > 1: + key_states = key_states.repeat_interleave(n_rep, dim=1) + value_states = value_states.repeat_interleave(n_rep, dim=1) + + attn_bias = None + if self.sliding_window is not None and not self.is_cross_attention: + indices = torch.arange(q_len, device=query_states.device) + diff = indices.unsqueeze(1) - indices.unsqueeze(0) + 
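+            # diff[i, j] = i - j, so |diff| <= sliding_window keeps a symmetric
+            # band of key positions around each query; everything outside the
+            # band gets the most negative finite bias below, which softmax turns
+            # into effectively zero attention weight.
+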
in_window = torch.abs(diff) <= self.sliding_window + + window_bias = torch.zeros((q_len, kv_len), device=query_states.device, dtype=query_states.dtype) + min_value = torch.finfo(query_states.dtype).min + window_bias.masked_fill_(~in_window, min_value) + + window_bias = window_bias.unsqueeze(0).unsqueeze(0) + + if attn_bias is not None: + if attn_bias.dtype == torch.bool: + base_bias = torch.zeros_like(window_bias) + base_bias.masked_fill_(~attn_bias, min_value) + attn_bias = base_bias + window_bias + else: + attn_bias = attn_bias + window_bias + else: + attn_bias = window_bias + + attn_output = optimized_attention(query_states, key_states, value_states, self.num_heads, attn_bias, skip_reshape=True, low_precision_attention=False) + attn_output = self.o_proj(attn_output) + + return attn_output + +class AceStepDiTLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + layer_type="full_attention", + sliding_window=128, + dtype=None, + device=None, + operations=None + ): + super().__init__() + + self_attn_window = sliding_window if layer_type == "sliding_attention" else None + + self.self_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=False, sliding_window=self_attn_window, + dtype=dtype, device=device, operations=operations + ) + + self.cross_attn_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.cross_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + is_cross_attention=True, dtype=dtype, device=device, operations=operations + ) + + self.mlp_norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 6, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + temb, + encoder_hidden_states, + position_embeddings, + attention_mask=None, + encoder_attention_mask=None + ): + modulation = comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb + shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = modulation.chunk(6, dim=1) + + norm_hidden = self.self_attn_norm(hidden_states) + norm_hidden = norm_hidden * (1 + scale_msa) + shift_msa + + attn_out = self.self_attn( + norm_hidden, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = hidden_states + attn_out * gate_msa + + norm_hidden = self.cross_attn_norm(hidden_states) + attn_out = self.cross_attn( + norm_hidden, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask + ) + hidden_states = hidden_states + attn_out + + norm_hidden = self.mlp_norm(hidden_states) + norm_hidden = norm_hidden * (1 + c_scale_msa) + c_shift_msa + + mlp_out = self.mlp(norm_hidden) + hidden_states = hidden_states + mlp_out * c_gate_msa + + return hidden_states + +class AceStepEncoderLayer(nn.Module): + def __init__( + self, + hidden_size, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.self_attn = AceStepAttention( + hidden_size, num_heads, num_kv_heads, head_dim, rms_norm_eps, + 
is_cross_attention=False, dtype=dtype, device=device, operations=operations + ) + self.input_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.post_attention_layernorm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.mlp = MLP(hidden_size, intermediate_size, dtype=dtype, device=device, operations=operations) + + def forward(self, hidden_states, position_embeddings, attention_mask=None): + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + +class AceStepLyricEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(text_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, inputs_embeds, attention_mask=None): + hidden_states = self.embed_tokens(inputs_embeds) + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + position_embeddings = (cos, sin) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=position_embeddings, + attention_mask=attention_mask + ) + + hidden_states = self.norm(hidden_states) + return hidden_states + +class AceStepTimbreEncoder(nn.Module): + def __init__( + self, + timbre_hidden_dim, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(timbre_hidden_dim, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + + def unpack_timbre_embeddings(self, timbre_embs_packed, refer_audio_order_mask): + N, d = timbre_embs_packed.shape + device = timbre_embs_packed.device + B = N + counts = torch.bincount(refer_audio_order_mask, minlength=B) + max_count = counts.max().item() + + sorted_indices = torch.argsort( + refer_audio_order_mask * N + 
torch.arange(N, device=device), + stable=True + ) + sorted_batch_ids = refer_audio_order_mask[sorted_indices] + + positions = torch.arange(N, device=device) + batch_starts = torch.cat([torch.tensor([0], device=device), torch.cumsum(counts, dim=0)[:-1]]) + positions_in_sorted = positions - batch_starts[sorted_batch_ids] + + inverse_indices = torch.empty_like(sorted_indices) + inverse_indices[sorted_indices] = torch.arange(N, device=device) + positions_in_batch = positions_in_sorted[inverse_indices] + + indices_2d = refer_audio_order_mask * max_count + positions_in_batch + one_hot = F.one_hot(indices_2d, num_classes=B * max_count).to(timbre_embs_packed.dtype) + + timbre_embs_flat = one_hot.t() @ timbre_embs_packed + timbre_embs_unpack = timbre_embs_flat.view(B, max_count, d) + + mask_flat = (one_hot.sum(dim=0) > 0).long() + new_mask = mask_flat.view(B, max_count) + return timbre_embs_unpack, new_mask + + def forward(self, refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, attention_mask=None): + hidden_states = self.embed_tokens(refer_audio_acoustic_hidden_states_packed) + if hidden_states.dim() == 2: + hidden_states = hidden_states.unsqueeze(0) + + seq_len = hidden_states.shape[1] + cos, sin = self.rotary_emb(hidden_states, seq_len=seq_len) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + position_embeddings=(cos, sin), + attention_mask=attention_mask + ) + hidden_states = self.norm(hidden_states) + + flat_states = hidden_states[:, 0, :] + unpacked_embs, unpacked_mask = self.unpack_timbre_embeddings(flat_states, refer_audio_order_mask) + return unpacked_embs, unpacked_mask + + +def pack_sequences(hidden1, hidden2, mask1, mask2): + hidden_cat = torch.cat([hidden1, hidden2], dim=1) + B, L, D = hidden_cat.shape + + if mask1 is not None and mask2 is not None: + mask_cat = torch.cat([mask1, mask2], dim=1) + sort_idx = mask_cat.argsort(dim=1, descending=True, stable=True) + gather_idx = sort_idx.unsqueeze(-1).expand(B, L, D) + hidden_sorted = torch.gather(hidden_cat, 1, gather_idx) + lengths = mask_cat.sum(dim=1) + new_mask = (torch.arange(L, device=hidden_cat.device).unsqueeze(0) < lengths.unsqueeze(1)) + else: + new_mask = None + hidden_sorted = hidden_cat + + return hidden_sorted, new_mask + +class AceStepConditionEncoder(nn.Module): + def __init__( + self, + text_hidden_dim, + timbre_hidden_dim, + hidden_size, + num_lyric_layers, + num_timbre_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.text_projector = Linear(text_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device) + + self.lyric_encoder = AceStepLyricEncoder( + text_hidden_dim=text_hidden_dim, + hidden_size=hidden_size, + num_layers=num_lyric_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + self.timbre_encoder = AceStepTimbreEncoder( + timbre_hidden_dim=timbre_hidden_dim, + hidden_size=hidden_size, + num_layers=num_timbre_layers, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + rms_norm_eps=rms_norm_eps, + dtype=dtype, + device=device, + operations=operations + ) + + def forward( + self, + text_hidden_states=None, + text_attention_mask=None, + lyric_hidden_states=None, + 
lyric_attention_mask=None, + refer_audio_acoustic_hidden_states_packed=None, + refer_audio_order_mask=None + ): + text_emb = self.text_projector(text_hidden_states) + + lyric_emb = self.lyric_encoder( + inputs_embeds=lyric_hidden_states, + attention_mask=lyric_attention_mask + ) + + timbre_emb, timbre_mask = self.timbre_encoder( + refer_audio_acoustic_hidden_states_packed, + refer_audio_order_mask + ) + + merged_emb, merged_mask = pack_sequences(lyric_emb, timbre_emb, lyric_attention_mask, timbre_mask) + final_emb, final_mask = pack_sequences(merged_emb, text_emb, merged_mask, text_attention_mask) + + return final_emb, final_mask + +# -------------------------------------------------------------------------------- +# Main Diffusion Model (DiT) +# -------------------------------------------------------------------------------- + +class AceStepDiTModel(nn.Module): + def __init__( + self, + in_channels, + hidden_size, + num_layers, + num_heads, + num_kv_heads, + head_dim, + intermediate_size, + patch_size, + audio_acoustic_hidden_dim, + layer_types=None, + sliding_window=128, + rms_norm_eps=1e-6, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.patch_size = patch_size + self.rotary_emb = RotaryEmbedding( + head_dim, + base=1000000.0, + dtype=dtype, + device=device, + operations=operations + ) + + Conv1d = get_layer_class(operations, "Conv1d") + ConvTranspose1d = get_layer_class(operations, "ConvTranspose1d") + Linear = get_layer_class(operations, "Linear") + + self.proj_in = nn.Sequential( + nn.Identity(), + Conv1d( + in_channels, hidden_size, kernel_size=patch_size, stride=patch_size, + dtype=dtype, device=device)) + + self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations) + self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + + if layer_types is None: + layer_types = ["full_attention"] * num_layers + + if len(layer_types) < num_layers: + layer_types = list(itertools.islice(itertools.cycle(layer_types), num_layers)) + + self.layers = nn.ModuleList([ + AceStepDiTLayer( + hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + layer_type=layer_types[i], + sliding_window=sliding_window, + dtype=dtype, device=device, operations=operations + ) for i in range(num_layers) + ]) + + self.norm_out = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.proj_out = nn.Sequential( + nn.Identity(), + ConvTranspose1d(hidden_size, audio_acoustic_hidden_dim, kernel_size=patch_size, stride=patch_size, dtype=dtype, device=device) + ) + + self.scale_shift_table = nn.Parameter(torch.empty(1, 2, hidden_size, dtype=dtype, device=device)) + + def forward( + self, + hidden_states, + timestep, + timestep_r, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + context_latents + ): + temb_t, proj_t = self.time_embed(timestep, dtype=hidden_states.dtype) + temb_r, proj_r = self.time_embed_r(timestep - timestep_r, dtype=hidden_states.dtype) + temb = temb_t + temb_r + timestep_proj = proj_t + proj_r + + x = torch.cat([context_latents, hidden_states], dim=-1) + original_seq_len = x.shape[1] + + pad_length = 0 + if x.shape[1] % self.patch_size != 0: + pad_length = self.patch_size - (x.shape[1] % self.patch_size) + x = F.pad(x, (0, 0, 0, pad_length), mode='constant', value=0) + + x = x.transpose(1, 2) + x = self.proj_in(x) + x = 
x.transpose(1, 2) + + encoder_hidden_states = self.condition_embedder(encoder_hidden_states) + + seq_len = x.shape[1] + cos, sin = self.rotary_emb(x, seq_len=seq_len) + + for layer in self.layers: + x = layer( + hidden_states=x, + temb=timestep_proj, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=(cos, sin), + attention_mask=None, + encoder_attention_mask=None + ) + + shift, scale = (comfy.model_management.cast_to(self.scale_shift_table, dtype=temb.dtype, device=temb.device) + temb.unsqueeze(1)).chunk(2, dim=1) + x = self.norm_out(x) * (1 + scale) + shift + + x = x.transpose(1, 2) + x = self.proj_out(x) + x = x.transpose(1, 2) + + x = x[:, :original_seq_len, :] + return x + + +class AttentionPooler(nn.Module): + def __init__(self, hidden_size, num_layers, head_dim, rms_norm_eps, dtype=None, device=None, operations=None): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.norm = operations.RMSNorm(hidden_size, eps=rms_norm_eps, dtype=dtype, device=device) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.special_token = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + + def forward(self, x): + B, T, P, D = x.shape + x = self.embed_tokens(x) + special = comfy.model_management.cast_to(self.special_token, device=x.device, dtype=x.dtype).expand(B, T, 1, -1) + x = torch.cat([special, x], dim=2) + x = x.view(B * T, P + 1, D) + + cos, sin = self.rotary_emb(x, seq_len=P + 1) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + return x[:, 0, :].view(B, T, D) + + +class FSQ(nn.Module): + def __init__( + self, + levels, + dim=None, + device=None, + dtype=None, + operations=None + ): + super().__init__() + + _levels = torch.tensor(levels, dtype=torch.int32, device=device) + self.register_buffer('_levels', _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1], dtype=torch.int32, device=device), dim=0) + self.register_buffer('_basis', _basis, persistent=False) + + self.codebook_dim = len(levels) + self.dim = dim if dim is not None else self.codebook_dim + + requires_projection = self.dim != self.codebook_dim + if requires_projection: + self.project_in = operations.Linear(self.dim, self.codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(self.codebook_dim, self.dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.codebook_size = self._levels.prod().item() + + indices = torch.arange(self.codebook_size, device=device) + implicit_codebook = self._indices_to_codes(indices) + + if dtype is not None: + implicit_codebook = implicit_codebook.to(dtype) + + self.register_buffer('implicit_codebook', implicit_codebook, persistent=False) + + def bound(self, z): + levels_minus_1 = (comfy.model_management.cast_to(self._levels, device=z.device, dtype=z.dtype) - 1) + scale = 2. / levels_minus_1 + bracket = (levels_minus_1 * (torch.tanh(z) + 1) / 2.) + 0.5 + + zhat = bracket.floor() + bracket_ste = bracket + (zhat - bracket).detach() + + return scale * bracket_ste - 1. 
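+        # bracket_ste is a straight-through estimator: the forward pass uses the
+        # rounded value while gradients flow through the continuous bracket, so
+        # dim d snaps to one of levels[d] uniform points in [-1, 1] (with the
+        # default levels [8, 8, 8, 5, 5, 5] that is 8*8*8*5*5*5 = 64000 codes).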
+ + def _indices_to_codes(self, indices): + indices = indices.unsqueeze(-1) + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered.float() * (2. / (self._levels.float() - 1)) - 1. + + def codes_to_indices(self, zhat): + zhat_normalized = (zhat + 1.) / (2. / (comfy.model_management.cast_to(self._levels, device=zhat.device, dtype=zhat.dtype) - 1)) + return (zhat_normalized * comfy.model_management.cast_to(self._basis, device=zhat.device, dtype=zhat.dtype)).sum(dim=-1).round().to(torch.int32) + + def forward(self, z): + orig_dtype = z.dtype + z = self.project_in(z) + + codes = self.bound(z) + indices = self.codes_to_indices(codes) + + out = self.project_out(codes) + return out.to(orig_dtype), indices + + +class ResidualFSQ(nn.Module): + def __init__( + self, + levels, + num_quantizers, + dim=None, + bound_hard_clamp=True, + device=None, + dtype=None, + operations=None, + **kwargs + ): + super().__init__() + + codebook_dim = len(levels) + dim = dim if dim is not None else codebook_dim + + requires_projection = codebook_dim != dim + if requires_projection: + self.project_in = operations.Linear(dim, codebook_dim, device=device, dtype=dtype) + self.project_out = operations.Linear(codebook_dim, dim, device=device, dtype=dtype) + else: + self.project_in = nn.Identity() + self.project_out = nn.Identity() + + self.layers = nn.ModuleList() + levels_tensor = torch.tensor(levels, device=device) + scales = [] + + for ind in range(num_quantizers): + scale_val = levels_tensor.float() ** -ind + scales.append(scale_val) + + self.layers.append(FSQ( + levels=levels, + dim=codebook_dim, + device=device, + dtype=dtype, + operations=operations + )) + + scales_tensor = torch.stack(scales) + if dtype is not None: + scales_tensor = scales_tensor.to(dtype) + self.register_buffer('scales', scales_tensor, persistent=False) + + if bound_hard_clamp: + val = 1 + (1 / (levels_tensor.float() - 1)) + if dtype is not None: + val = val.to(dtype) + self.register_buffer('soft_clamp_input_value', val, persistent=False) + + def get_output_from_indices(self, indices, dtype=torch.float32): + if indices.dim() == 2: + indices = indices.unsqueeze(-1) + + all_codes = [] + for i, layer in enumerate(self.layers): + idx = indices[..., i].long() + codes = F.embedding(idx, comfy.model_management.cast_to(layer.implicit_codebook, device=idx.device, dtype=dtype)) + all_codes.append(codes * comfy.model_management.cast_to(self.scales[i], device=idx.device, dtype=dtype)) + + codes_summed = torch.stack(all_codes).sum(dim=0) + return self.project_out(codes_summed) + + def forward(self, x): + x = self.project_in(x) + + if hasattr(self, 'soft_clamp_input_value'): + sc_val = comfy.model_management.cast_to(self.soft_clamp_input_value, device=x.device, dtype=x.dtype) + x = (x / sc_val).tanh() * sc_val + + quantized_out = torch.tensor(0., device=x.device, dtype=x.dtype) + residual = x + all_indices = [] + + for layer, scale in zip(self.layers, self.scales): + scale = comfy.model_management.cast_to(scale, device=x.device, dtype=x.dtype) + + quantized, indices = layer(residual / scale) + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + all_indices.append(indices) + + quantized_out = self.project_out(quantized_out) + all_indices = torch.stack(all_indices, dim=-1) + + return quantized_out, all_indices + + +class AceStepAudioTokenizer(nn.Module): + def __init__( + self, + audio_acoustic_hidden_dim, + hidden_size, + pool_window_size, + fsq_dim, + 
fsq_levels, + fsq_input_num_quantizers, + num_layers, + head_dim, + rms_norm_eps, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.audio_acoustic_proj = Linear(audio_acoustic_hidden_dim, hidden_size, dtype=dtype, device=device) + self.attention_pooler = AttentionPooler( + hidden_size, num_layers, head_dim, rms_norm_eps, dtype=dtype, device=device, operations=operations + ) + self.pool_window_size = pool_window_size + self.fsq_dim = fsq_dim + self.quantizer = ResidualFSQ( + dim=fsq_dim, + levels=fsq_levels, + num_quantizers=fsq_input_num_quantizers, + bound_hard_clamp=True, + dtype=dtype, device=device, operations=operations + ) + + def forward(self, hidden_states): + hidden_states = self.audio_acoustic_proj(hidden_states) + hidden_states = self.attention_pooler(hidden_states) + quantized, indices = self.quantizer(hidden_states) + return quantized, indices + + def tokenize(self, x): + B, T, D = x.shape + P = self.pool_window_size + + if T % P != 0: + pad = P - (T % P) + x = F.pad(x, (0, 0, 0, pad)) + T = x.shape[1] + + T_patch = T // P + x = x.view(B, T_patch, P, D) + + quantized, indices = self.forward(x) + return quantized, indices + + +class AudioTokenDetokenizer(nn.Module): + def __init__( + self, + hidden_size, + pool_window_size, + audio_acoustic_hidden_dim, + num_layers, + head_dim, + dtype=None, + device=None, + operations=None + ): + super().__init__() + Linear = get_layer_class(operations, "Linear") + self.pool_window_size = pool_window_size + self.embed_tokens = Linear(hidden_size, hidden_size, dtype=dtype, device=device) + self.special_tokens = nn.Parameter(torch.empty(1, pool_window_size, hidden_size, dtype=dtype, device=device)) + self.rotary_emb = RotaryEmbedding(head_dim, dtype=dtype, device=device, operations=operations) + self.layers = nn.ModuleList([ + AceStepEncoderLayer( + hidden_size, 16, 8, head_dim, hidden_size * 3, 1e-6, + dtype=dtype, device=device, operations=operations + ) + for _ in range(num_layers) + ]) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) + self.proj_out = Linear(hidden_size, audio_acoustic_hidden_dim, dtype=dtype, device=device) + + def forward(self, x): + B, T, D = x.shape + x = self.embed_tokens(x) + x = x.unsqueeze(2).repeat(1, 1, self.pool_window_size, 1) + x = x + comfy.model_management.cast_to(self.special_tokens.expand(B, T, -1, -1), device=x.device, dtype=x.dtype) + x = x.view(B * T, self.pool_window_size, D) + + cos, sin = self.rotary_emb(x, seq_len=self.pool_window_size) + for layer in self.layers: + x = layer(x, (cos, sin)) + + x = self.norm(x) + x = self.proj_out(x) + return x.view(B, T * self.pool_window_size, -1) + + +class AceStepConditionGenerationModel(nn.Module): + def __init__( + self, + in_channels=192, + hidden_size=2048, + text_hidden_dim=1024, + timbre_hidden_dim=64, + audio_acoustic_hidden_dim=64, + num_dit_layers=24, + num_lyric_layers=8, + num_timbre_layers=4, + num_tokenizer_layers=2, + num_heads=16, + num_kv_heads=8, + head_dim=128, + intermediate_size=6144, + patch_size=2, + pool_window_size=5, + rms_norm_eps=1e-06, + timestep_mu=-0.4, + timestep_sigma=1.0, + data_proportion=0.5, + sliding_window=128, + layer_types=None, + fsq_dim=2048, + fsq_levels=[8, 8, 8, 5, 5, 5], + fsq_input_num_quantizers=1, + audio_model=None, + dtype=None, + device=None, + operations=None + ): + super().__init__() + self.dtype = dtype + self.timestep_mu = timestep_mu + self.timestep_sigma = timestep_sigma + self.data_proportion = 
data_proportion + self.pool_window_size = pool_window_size + + if layer_types is None: + layer_types = [] + for i in range(num_dit_layers): + layer_types.append("sliding_attention" if i % 2 == 0 else "full_attention") + + self.decoder = AceStepDiTModel( + in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim, + intermediate_size, patch_size, audio_acoustic_hidden_dim, + layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.encoder = AceStepConditionEncoder( + text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers, + num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.tokenizer = AceStepAudioTokenizer( + audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps, + dtype=dtype, device=device, operations=operations + ) + self.detokenizer = AudioTokenDetokenizer( + hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim, + dtype=dtype, device=device, operations=operations + ) + self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device)) + + def prepare_condition( + self, + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, + precomputed_lm_hints_25Hz=None, + audio_codes=None + ): + encoder_hidden, encoder_mask = self.encoder( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask + ) + + if precomputed_lm_hints_25Hz is not None: + lm_hints = precomputed_lm_hints_25Hz + else: + if audio_codes is not None: + if audio_codes.shape[1] * 5 < src_latents.shape[1]: + audio_codes = torch.nn.functional.pad(audio_codes, (0, math.ceil(src_latents.shape[1] / 5) - audio_codes.shape[1]), "constant", 35847) + lm_hints_5Hz = self.tokenizer.quantizer.get_output_from_indices(audio_codes, dtype=text_hidden_states.dtype) + else: + lm_hints_5Hz, indices = self.tokenizer.tokenize(refer_audio_acoustic_hidden_states_packed) + + lm_hints = self.detokenizer(lm_hints_5Hz) + + lm_hints = lm_hints[:, :src_latents.shape[1], :] + if is_covers is None or is_covers is True: + src_latents = lm_hints + elif is_covers is False: + src_latents = refer_audio_acoustic_hidden_states_packed + + context_latents = torch.cat([src_latents, chunk_masks.to(src_latents.dtype)], dim=-1) + + return encoder_hidden, encoder_mask, context_latents + + def forward(self, x, timestep, context, lyric_embed=None, refer_audio=None, audio_codes=None, is_covers=None, replace_with_null_embeds=False, **kwargs): + text_attention_mask = None + lyric_attention_mask = None + refer_audio_order_mask = None + attention_mask = None + chunk_masks = None + src_latents = None + precomputed_lm_hints_25Hz = None + lyric_hidden_states = lyric_embed + text_hidden_states = context + refer_audio_acoustic_hidden_states_packed = refer_audio.movedim(-1, -2) + + x = x.movedim(-1, -2) + + if refer_audio_order_mask is None: + refer_audio_order_mask = torch.zeros((x.shape[0],), device=x.device, dtype=torch.long) + + if src_latents is None: + src_latents = x + + if chunk_masks is None: + 
chunk_masks = torch.ones_like(x) + + enc_hidden, enc_mask, context_latents = self.prepare_condition( + text_hidden_states, text_attention_mask, + lyric_hidden_states, lyric_attention_mask, + refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask, + src_latents, chunk_masks, is_covers, precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, audio_codes=audio_codes + ) + + if replace_with_null_embeds: + enc_hidden[:] = self.null_condition_emb.to(enc_hidden) + + out = self.decoder(hidden_states=x, + timestep=timestep, + timestep_r=timestep, + attention_mask=attention_mask, + encoder_hidden_states=enc_hidden, + encoder_attention_mask=enc_mask, + context_latents=context_latents + ) + + return out.movedim(-1, -2) diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py new file mode 100644 index 000000000..6fcf8df90 --- /dev/null +++ b/comfy/ldm/anima/model.py @@ -0,0 +1,214 @@ +from comfy.ldm.cosmos.predict2 import MiniTrainDIT +import torch +from torch import nn +import torch.nn.functional as F + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1): + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + x_embed = (x * cos) + (rotate_half(x) * sin) + return x_embed + + +class RotaryEmbedding(nn.Module): + def __init__(self, head_dim): + super().__init__() + self.rope_theta = 10000 + inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).to(dtype=torch.float) / head_dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Attention(nn.Module): + def __init__(self, query_dim, context_dim, n_heads, head_dim, device=None, dtype=None, operations=None): + super().__init__() + + inner_dim = head_dim * n_heads + self.n_heads = n_heads + self.head_dim = head_dim + self.query_dim = query_dim + self.context_dim = context_dim + + self.q_proj = operations.Linear(query_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.k_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, device=device, dtype=dtype) + + self.v_proj = operations.Linear(context_dim, inner_dim, bias=False, device=device, dtype=dtype) + + self.o_proj = operations.Linear(inner_dim, query_dim, bias=False, device=device, dtype=dtype) + + def forward(self, x, mask=None, context=None, position_embeddings=None, position_embeddings_context=None): + context = x if context is None else context + input_shape = x.shape[:-1] + q_shape = (*input_shape, self.n_heads, self.head_dim) + context_shape = context.shape[:-1] + kv_shape = (*context_shape, self.n_heads, self.head_dim) + + query_states = 
self.q_norm(self.q_proj(x).view(q_shape)).transpose(1, 2) + key_states = self.k_norm(self.k_proj(context).view(kv_shape)).transpose(1, 2) + value_states = self.v_proj(context).view(kv_shape).transpose(1, 2) + + if position_embeddings is not None: + assert position_embeddings_context is not None + cos, sin = position_embeddings + query_states = apply_rotary_pos_emb(query_states, cos, sin) + cos, sin = position_embeddings_context + key_states = apply_rotary_pos_emb(key_states, cos, sin) + + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=mask) + + attn_output = attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + def init_weights(self): + torch.nn.init.zeros_(self.o_proj.weight) + + +class TransformerBlock(nn.Module): + def __init__(self, source_dim, model_dim, num_heads=16, mlp_ratio=4.0, use_self_attn=False, layer_norm=False, device=None, dtype=None, operations=None): + super().__init__() + self.use_self_attn = use_self_attn + + if self.use_self_attn: + self.norm_self_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.self_attn = Attention( + query_dim=model_dim, + context_dim=model_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_cross_attn = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.cross_attn = Attention( + query_dim=model_dim, + context_dim=source_dim, + n_heads=num_heads, + head_dim=model_dim//num_heads, + device=device, + dtype=dtype, + operations=operations, + ) + + self.norm_mlp = operations.LayerNorm(model_dim, device=device, dtype=dtype) if layer_norm else operations.RMSNorm(model_dim, eps=1e-6, device=device, dtype=dtype) + self.mlp = nn.Sequential( + operations.Linear(model_dim, int(model_dim * mlp_ratio), device=device, dtype=dtype), + nn.GELU(), + operations.Linear(int(model_dim * mlp_ratio), model_dim, device=device, dtype=dtype) + ) + + def forward(self, x, context, target_attention_mask=None, source_attention_mask=None, position_embeddings=None, position_embeddings_context=None): + if self.use_self_attn: + normed = self.norm_self_attn(x) + attn_out = self.self_attn(normed, mask=target_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings) + x = x + attn_out + + normed = self.norm_cross_attn(x) + attn_out = self.cross_attn(normed, mask=source_attention_mask, context=context, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + x = x + attn_out + + x = x + self.mlp(self.norm_mlp(x)) + return x + + def init_weights(self): + torch.nn.init.zeros_(self.mlp[2].weight) + self.cross_attn.init_weights() + + +class LLMAdapter(nn.Module): + def __init__( + self, + source_dim=1024, + target_dim=1024, + model_dim=1024, + num_layers=6, + num_heads=16, + use_self_attn=True, + layer_norm=False, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + + self.embed = operations.Embedding(32128, target_dim, device=device, dtype=dtype) + if model_dim != target_dim: + self.in_proj = operations.Linear(target_dim, model_dim, device=device, dtype=dtype) + else: + self.in_proj = nn.Identity() + self.rotary_emb = RotaryEmbedding(model_dim//num_heads) + 
self.blocks = nn.ModuleList([ + TransformerBlock(source_dim, model_dim, num_heads=num_heads, use_self_attn=use_self_attn, layer_norm=layer_norm, device=device, dtype=dtype, operations=operations) for _ in range(num_layers) + ]) + self.out_proj = operations.Linear(model_dim, target_dim, device=device, dtype=dtype) + self.norm = operations.RMSNorm(target_dim, eps=1e-6, device=device, dtype=dtype) + + def forward(self, source_hidden_states, target_input_ids, target_attention_mask=None, source_attention_mask=None): + if target_attention_mask is not None: + target_attention_mask = target_attention_mask.to(torch.bool) + if target_attention_mask.ndim == 2: + target_attention_mask = target_attention_mask.unsqueeze(1).unsqueeze(1) + + if source_attention_mask is not None: + source_attention_mask = source_attention_mask.to(torch.bool) + if source_attention_mask.ndim == 2: + source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) + + context = source_hidden_states + x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype)) + position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) + position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) + position_embeddings = self.rotary_emb(x, position_ids) + position_embeddings_context = self.rotary_emb(x, position_ids_context) + for block in self.blocks: + x = block(x, context, target_attention_mask=target_attention_mask, source_attention_mask=source_attention_mask, position_embeddings=position_embeddings, position_embeddings_context=position_embeddings_context) + return self.norm(self.out_proj(x)) + + +class Anima(MiniTrainDIT): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations")) + + def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None): + if text_ids is not None: + out = self.llm_adapter(text_embeds, text_ids) + if t5xxl_weights is not None: + out = out * t5xxl_weights + + if out.shape[1] < 512: + out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1])) + return out + else: + return text_embeds + + def forward(self, x, timesteps, context, **kwargs): + t5xxl_ids = kwargs.pop("t5xxl_ids", None) + if t5xxl_ids is not None: + context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None)) + return super().forward(x, timesteps, context, **kwargs) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 145e6e69a..e4e30cacd 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -136,16 +136,7 @@ class ResBlock(nn.Module): ops.Linear(c_hidden, c), ) - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2d5684348..df348a8ed 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -3,7 +3,6 @@ from torch import Tensor, nn from comfy.ldm.flux.layers import ( MLPEmbedder, - RMSNorm, ModulationOut, ) @@ -29,7 +28,7 @@ class 
Approximator(nn.Module): super().__init__() self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) - self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)]) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) @property diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index 2e8ef0687..9fd865f20 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -152,6 +152,7 @@ class Chroma(nn.Module): transformer_options={}, attn_mask: Tensor = None, ) -> Tensor: + transformer_options = transformer_options.copy() patches_replace = transformer_options.get("patches_replace", {}) # running on sequences img @@ -228,6 +229,7 @@ class Chroma(nn.Module): transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" + transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] for i, block in enumerate(self.single_blocks): transformer_options["block_index"] = i if i not in self.skip_dit: diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py index 3c7bc9b6b..08d31e0ba 100644 --- a/comfy/ldm/chroma_radiance/layers.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -4,8 +4,6 @@ from functools import lru_cache import torch from torch import nn -from comfy.ldm.flux.layers import RMSNorm - class NerfEmbedder(nn.Module): """ @@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module): # We now need to generate parameters for 3 matrices. 
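+        # (Presumably 3 matrices of hidden_size_x x (hidden_size_x * mlp_ratio) entries each,
+        # which is where the 3 * hidden_size_x**2 * mlp_ratio count below comes from.)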
total_params = 3 * hidden_size_x**2 * mlp_ratio self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) - self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device) self.mlp_ratio = mlp_ratio @@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module): class NerfFinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module): class NerfFinalLayerConv(nn.Module): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.conv = operations.Conv2d( in_channels=hidden_size, out_channels=out_channels, diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 70d173889..4fb56165e 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -270,7 +270,7 @@ class ChromaRadiance(Chroma): bad_keys = tuple( k for k, v in overrides.items() - if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys) + if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys) ) if bad_keys: e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}" diff --git a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 07a4fc79f..2268bff38 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -13,6 +13,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention +import comfy.ldm.common_dit def apply_rotary_pos_emb( t: torch.Tensor, @@ -334,7 +335,7 @@ class FinalLayer(nn.Module): device=None, dtype=None, operations=None ): super().__init__() - self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.layer_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) self.linear = operations.Linear( hidden_size, spatial_patch_size * spatial_patch_size * temporal_patch_size * out_channels, bias=False, device=device, dtype=dtype ) @@ -462,6 +463,8 @@ class Block(nn.Module): extra_per_block_pos_emb: Optional[torch.Tensor] = None, transformer_options: Optional[dict] = {}, ) -> torch.Tensor: + residual_dtype = x_B_T_H_W_D.dtype + compute_dtype = emb_B_T_D.dtype if extra_per_block_pos_emb is not None: x_B_T_H_W_D = x_B_T_H_W_D + extra_per_block_pos_emb @@ -511,7 +514,7 @@ class Block(nn.Module): result_B_T_H_W_D = rearrange( self.self_attn( # normalized_x_B_T_HW_D, - rearrange(normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), None, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -521,7 +524,7 @@ class Block(nn.Module): h=H, w=W, ) - x_B_T_H_W_D = x_B_T_H_W_D + gate_self_attn_B_T_1_1_D * result_B_T_H_W_D + x_B_T_H_W_D = x_B_T_H_W_D + 
gate_self_attn_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) def _x_fn( _x_B_T_H_W_D: torch.Tensor, @@ -535,7 +538,7 @@ class Block(nn.Module): ) _result_B_T_H_W_D = rearrange( self.cross_attn( - rearrange(_normalized_x_B_T_H_W_D, "b t h w d -> b (t h w) d"), + rearrange(_normalized_x_B_T_H_W_D.to(compute_dtype), "b t h w d -> b (t h w) d"), crossattn_emb, rope_emb=rope_emb_L_1_1_D, transformer_options=transformer_options, @@ -554,7 +557,7 @@ class Block(nn.Module): shift_cross_attn_B_T_1_1_D, transformer_options=transformer_options, ) - x_B_T_H_W_D = result_B_T_H_W_D * gate_cross_attn_B_T_1_1_D + x_B_T_H_W_D + x_B_T_H_W_D = result_B_T_H_W_D.to(residual_dtype) * gate_cross_attn_B_T_1_1_D.to(residual_dtype) + x_B_T_H_W_D normalized_x_B_T_H_W_D = _fn( x_B_T_H_W_D, @@ -562,8 +565,8 @@ class Block(nn.Module): scale_mlp_B_T_1_1_D, shift_mlp_B_T_1_1_D, ) - result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D) - x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D * result_B_T_H_W_D + result_B_T_H_W_D = self.mlp(normalized_x_B_T_H_W_D.to(compute_dtype)) + x_B_T_H_W_D = x_B_T_H_W_D + gate_mlp_B_T_1_1_D.to(residual_dtype) * result_B_T_H_W_D.to(residual_dtype) return x_B_T_H_W_D @@ -835,6 +838,8 @@ class MiniTrainDIT(nn.Module): padding_mask: Optional[torch.Tensor] = None, **kwargs, ): + orig_shape = list(x.shape) + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_temporal, self.patch_spatial, self.patch_spatial)) x_B_C_T_H_W = x timesteps_B_T = timesteps crossattn_emb = context @@ -873,6 +878,14 @@ class MiniTrainDIT(nn.Module): "extra_per_block_pos_emb": extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D, "transformer_options": kwargs.get("transformer_options", {}), } + + # The residual stream for this model has large values. To make fp16 compute_dtype work, we keep the residual stream + # in fp32, but run attention and MLP modules in fp16. + # An alternate method that clamps fp16 values "works" in the sense that it makes coherent images, but there is noticeable + # quality degradation and visual artifacts. + if x_B_T_H_W_D.dtype == torch.float16: + x_B_T_H_W_D = x_B_T_H_W_D.float() + for block in self.blocks: x_B_T_H_W_D = block( x_B_T_H_W_D, @@ -881,6 +894,6 @@ class MiniTrainDIT(nn.Module): **block_kwargs, ) - x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D, t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) - x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O) + x_B_T_H_W_O = self.final_layer(x_B_T_H_W_D.to(crossattn_emb.dtype), t_embedding_B_T_D, adaln_lora_B_T_3D=adaln_lora_B_T_3D) + x_B_C_Tt_Hp_Wp = self.unpatchify(x_B_T_H_W_O)[:, :, :orig_shape[-3], :orig_shape[-2], :orig_shape[-1]] return x_B_C_Tt_Hp_Wp diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 60f2bdae2..e28d704b4 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -5,9 +5,9 @@ import torch from torch import Tensor, nn from .math import attention, rope -import comfy.ops -import comfy.ldm.common_dit +# Fix import for some custom nodes, TODO: delete eventually. 
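+# (Presumably for custom nodes doing `from comfy.ldm.flux.layers import RMSNorm`: the import
+# itself keeps working, but the name is now a None placeholder, so instantiating it will fail
+# loudly. In-tree code uses `operations.RMSNorm(dim, dtype=..., device=...)` instead, as in the
+# QKNorm change below.)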
+RMSNorm = None class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list): @@ -87,20 +87,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) - - def forward(self, x: Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) - class QKNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): super().__init__() - self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) @@ -152,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor * m_mult else: for d in modulation_dims: - tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] + tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1] if m_add is not None: - tensor[:, d[0]:d[1]] += m_add[:, d[2]] + tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1] return tensor @@ -169,7 +161,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -197,8 +189,6 @@ class DoubleStreamBlock(nn.Module): self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) - self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): if self.modulation: img_mod1, img_mod2 = self.img_mod(vec) @@ -206,6 +196,9 @@ class DoubleStreamBlock(nn.Module): else: (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # prepare image for attention img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) @@ -224,32 +217,30 @@ class DoubleStreamBlock(nn.Module): del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - if self.flipped_img_txt: - q = torch.cat((img_q, txt_q), dim=2) - del img_q, txt_q - k = torch.cat((img_k, txt_k), dim=2) - del img_k, txt_k - v = torch.cat((img_v, txt_v), dim=2) - del img_v, txt_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v - 
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] - else: - q = torch.cat((txt_q, img_q), dim=2) - del txt_q, img_q - k = torch.cat((txt_k, img_k), dim=2) - del txt_k, img_k - v = torch.cat((txt_v, img_v), dim=2) - del txt_v, img_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v + extra_options["img_slice"] = [txt.shape[1], q.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] + # run actual attention + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v + + if "attn1_output_patch" in transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + attn = p(attn, extra_options) + + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) @@ -328,15 +319,30 @@ class SingleStreamBlock(nn.Module): else: mod = vec + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) del qkv q, k = self.norm(q, k, v) + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options) + q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask) + # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) del q, k, v + + if "attn1_output_patch" in transformer_patches: + patch = transformer_patches["attn1_output_patch"] + for p in patch: + attn = p(attn, extra_options) + # compute activation in mlp stream, cat again and run second linear layer if self.yak_mlp: mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6a22df8bc..824daf5e6 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -4,6 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import logging def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -13,7 +14,6 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transforme x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x - def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled(): @@ -28,13 +28,37 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) return 
out.to(dtype=torch.float32, device=pos.device) -def apply_rope1(x: Tensor, freqs_cis: Tensor): + +def _apply_rope1(x: Tensor, freqs_cis: Tensor): x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) + if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]: + freqs_cis = freqs_cis[:, :, :x_.shape[2]] x_out = freqs_cis[..., 0] * x_[..., 0] x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) return x_out.reshape(*x.shape).type_as(x) -def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): + +def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + + +try: + import comfy.quant_ops + q_apply_rope = comfy.quant_ops.ck.apply_rope + q_apply_rope1 = comfy.quant_ops.ck.apply_rope1 + def apply_rope(xq, xk, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope(xq, xk, freqs_cis) + else: + return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) + def apply_rope1(x, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope1(x, freqs_cis) + else: + return q_apply_rope1(x, freqs_cis) +except: + logging.warning("No comfy kitchen, using old apply_rope functions.") + apply_rope = _apply_rope + apply_rope1 = _apply_rope1 diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index f40c2a7a9..2020326c2 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -16,7 +16,6 @@ from .layers import ( SingleStreamBlock, timestep_embedding, Modulation, - RMSNorm ) @dataclass @@ -45,6 +44,22 @@ class FluxParams: txt_norm: bool = False +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + class Flux(nn.Module): """ Transformer model for flow matching on sequences. 
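A quick sketch of the helper added above (hypothetical inputs, not part of the patch): invert_slices returns the complement of the given ranges over [0, length), implicitly merging overlapping input ranges via current = max(current, end).

    invert_slices([(2, 4), (6, 8)], 10)   # -> [(0, 2), (4, 6), (8, 10)]
    invert_slices([(0, 5), (3, 7)], 10)   # -> [(7, 10)]

forward_orig below uses this to build modulation_dims: ranges outside timestep_zero_index keep the regular timestep embedding (index 0), while the timestep_zero_index ranges use the zeroed timestep appended to the batch (index 1).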
@@ -81,7 +96,7 @@ class Flux(nn.Module):
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

         if params.txt_norm:
-            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+            self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
         else:
             self.txt_norm = None

@@ -139,10 +154,12 @@ class Flux(nn.Module):
         y: Tensor,
         guidance: Tensor = None,
         control = None,
+        timestep_zero_index=None,
         transformer_options={},
         attn_mask: Tensor = None,
     ) -> Tensor:
+        transformer_options = transformer_options.copy()
         patches = transformer_options.get("patches", {})
         patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
@@ -164,13 +181,9 @@ class Flux(nn.Module):
             txt = self.txt_norm(txt)
         txt = self.txt_in(txt)

-        vec_orig = vec
-        if self.params.global_modulation:
-            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
-
         if "post_input" in patches:
             for p in patches["post_input"]:
-                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
+                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
                 img = out["img"]
                 txt = out["txt"]
                 img_ids = out["img_ids"]
@@ -182,6 +195,24 @@ class Flux(nn.Module):
         else:
             pe = None

+        vec_orig = vec
+        txt_vec = vec
+        extra_kwargs = {}
+        if timestep_zero_index is not None:
+            modulation_dims = []
+            batch = vec.shape[0] // 2
+            vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
+            invert = invert_slices(timestep_zero_index, img.shape[1])
+            for s in invert:
+                modulation_dims.append((s[0], s[1], 0))
+            for s in timestep_zero_index:
+                modulation_dims.append((s[0], s[1], 1))
+            extra_kwargs["modulation_dims_img"] = modulation_dims
+            txt_vec = vec[:batch]
+
+        if self.params.global_modulation:
+            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
+
         blocks_replace = patches_replace.get("dit", {})
         transformer_options["total_blocks"] = len(self.double_blocks)
         transformer_options["block_type"] = "double"
@@ -195,7 +226,8 @@ class Flux(nn.Module):
                                        vec=args["vec"],
                                        pe=args["pe"],
                                        attn_mask=args.get("attn_mask"),
-                                       transformer_options=args.get("transformer_options"))
+                                       transformer_options=args.get("transformer_options"),
+                                       **extra_kwargs)
                 return out

                 out = blocks_replace[("double_block", i)]({"img": img,
@@ -213,7 +245,8 @@ class Flux(nn.Module):
                              vec=vec,
                              pe=pe,
                              attn_mask=attn_mask,
-                             transformer_options=transformer_options)
+                             transformer_options=transformer_options,
+                             **extra_kwargs)

         if control is not None: # Controlnet
             control_i = control.get("input")
@@ -230,8 +263,15 @@ class Flux(nn.Module):
         if self.params.global_modulation:
             vec, _ = self.single_stream_modulation(vec_orig)

+        extra_kwargs = {}
+        if timestep_zero_index is not None:
+            # txt tokens come first in the single-stream sequence: shift img ranges by txt.shape[1] (a range starting at 0 keeps its start so it also covers the txt tokens)
+            modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
+            extra_kwargs["modulation_dims"] = modulation_dims_combined
+
         transformer_options["total_blocks"] = len(self.single_blocks)
         transformer_options["block_type"] = "single"
+        transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
         for i, block in enumerate(self.single_blocks):
             transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
@@ -241,7 +281,8 @@ class Flux(nn.Module):
                                        vec=args["vec"],
                                        pe=args["pe"],
attn_mask=args.get("attn_mask"), - transformer_options=args.get("transformer_options")) + transformer_options=args.get("transformer_options"), + **extra_kwargs) return out out = blocks_replace[("single_block", i)]({"img": img, @@ -252,7 +293,7 @@ class Flux(nn.Module): {"original_block": block_wrap}) img = out["img"] else: - img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) + img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs) if control is not None: # Controlnet control_o = control.get("output") @@ -263,7 +304,11 @@ class Flux(nn.Module): img = img[:, txt.shape[1] :, ...] - img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) + extra_kwargs = {} + if timestep_zero_index is not None: + extra_kwargs["modulation_dims"] = modulation_dims + + img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels) return img def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): @@ -311,13 +356,16 @@ class Flux(nn.Module): w_len = ((w_orig + (patch_size // 2)) // patch_size) img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] + timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) + timestep_zero = ref_latents_method == "index_timestep_zero" for ref in ref_latents: - if ref_latents_method == "index": + if ref_latents_method in ("index", "index_timestep_zero"): index += self.params.ref_index_scale h_offset = 0 w_offset = 0 @@ -338,9 +386,16 @@ class Flux(nn.Module): h = max(h, ref.shape[-2] + h_offset) w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) + if timestep_zero: + if index > 0: + timestep = torch.cat([timestep, timestep * 0], dim=0) + timestep_zero_index = [[img_tokens, img_ids.shape[1]]] + transformer_options = transformer_options.copy() + transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) @@ -348,6 +403,6 @@ class Flux(nn.Module): for i in self.params.txt_ids_dims: txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) - out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index d48d9d642..f67ba84e9 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -343,6 +343,7 @@ class 
CrossAttention(nn.Module): k.reshape(b, s2, self.num_heads * self.head_dim), v, heads=self.num_heads, + low_precision_attention=False, ) out = self.out_proj(x) @@ -412,6 +413,7 @@ class Attention(nn.Module): key.reshape(B, N, self.num_heads * self.head_dim), value, heads=self.num_heads, + low_precision_attention=False, ) x = self.out_proj(x) diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 55ab550f8..b94cdfa87 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, - flipped_img_txt=True, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -305,6 +304,7 @@ class HunyuanVideo(nn.Module): control=None, transformer_options={}, ) -> Tensor: + transformer_options = transformer_options.copy() patches_replace = transformer_options.get("patches_replace", {}) initial_shape = list(img.shape) @@ -378,14 +378,14 @@ class HunyuanVideo(nn.Module): extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) - ids = torch.cat((img_ids, txt_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) img_len = img.shape[1] if txt_mask is not None: attn_mask_len = img_len + txt.shape[1] attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) - attn_mask[:, 0, img_len:] = txt_mask + attn_mask[:, 0, :txt.shape[1]] = txt_mask else: attn_mask = None @@ -413,10 +413,11 @@ class HunyuanVideo(nn.Module): if add is not None: img += add - img = torch.cat((img, txt), 1) + img = torch.cat((txt, img), 1) transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" + transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] for i, block in enumerate(self.single_blocks): transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: @@ -435,9 +436,9 @@ class HunyuanVideo(nn.Module): if i < len(control_o): add = control_o[i] if add is not None: - img[:, : img_len] += add + img[:, txt.shape[1]: img_len + txt.shape[1]] += add - img = img[:, : img_len] + img = img[:, txt.shape[1]: img_len + txt.shape[1]] if ref_latent is not None: img = img[:, ref_latent.shape[1]:] diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 85f515f67..1f68144e2 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -3,7 +3,8 @@ import torch.nn as nn import torch.nn.functional as F from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm -import model_management, model_patcher +import comfy.model_management +import comfy.model_patcher class SRResidualCausalBlock3D(nn.Module): def __init__(self, channels: int): @@ -102,20 +103,20 @@ UPSAMPLERS = { class HunyuanVideo15SRModel(): def __init__(self, model_type, config): - self.load_device = model_management.vae_device() - offload_device = model_management.vae_offload_device() - self.dtype = model_management.vae_dtype(self.load_device) + self.load_device = comfy.model_management.vae_device() + offload_device = comfy.model_management.vae_offload_device() + self.dtype = comfy.model_management.vae_dtype(self.load_device) self.model_class = UPSAMPLERS.get(model_type) self.model = 
self.model_class(**config).eval() - self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) def load_sd(self, sd): - return self.model.load_state_dict(sd, strict=True) + return self.model.load_state_dict(sd, strict=True, assign=self.patcher.is_dynamic()) def get_sd(self): return self.model.state_dict() def resample_latent(self, latent): - model_management.load_model_gpu(self.patcher) + comfy.model_management.load_model_gpu(self.patcher) return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000..6f2ba41ef --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,1071 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + NormSingleLinearTextProjection, + LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import comfy.ldm.common_dit + +class CompressedTimestep: + """Store video timestep embeddings in compressed form using per-frame indexing.""" + __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') + + def __init__(self, tensor: torch.Tensor, patches_per_frame: int): + """ + tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame + patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression + """ + self.batch_size, num_tokens, self.feature_dim = tensor.shape + + # Check if compression is valid (num_tokens must be divisible by patches_per_frame) + if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame: + self.patches_per_frame = patches_per_frame + self.num_frames = num_tokens // patches_per_frame + + # Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame + # All patches in a frame are identical, so we only keep the first one + reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) + self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] + else: + # Not divisible or too small - store directly without compression + self.patches_per_frame = 1 + self.num_frames = num_tokens + self.data = tensor + + def expand(self): + """Expand back to original tensor.""" + if self.patches_per_frame == 1: + return self.data + + # [batch, frames, feature_dim] -> [batch, frames, patches_per_frame, feature_dim] -> [batch, tokens, feature_dim] + expanded = self.data.unsqueeze(2).expand(self.batch_size, self.num_frames, self.patches_per_frame, self.feature_dim) + return expanded.reshape(self.batch_size, -1, self.feature_dim) + + def expand_for_computation(self, scale_shift_table: torch.Tensor, batch_size: int, indices: slice = slice(None, None)): + """Compute ada values on compressed per-frame data, then expand spatially.""" + num_ada_params = scale_shift_table.shape[0] + + # No compression - compute directly + if self.patches_per_frame == 1: + num_tokens = self.data.shape[1] + 
dim_per_param = self.feature_dim // num_ada_params + reshaped = self.data.reshape(batch_size, num_tokens, num_ada_params, dim_per_param)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=self.data.device, dtype=self.data.dtype) + ada_values = (table_values + reshaped).unbind(dim=2) + return ada_values + + # Compressed: compute on per-frame data then expand spatially + # Reshape: [batch, frames, feature_dim] -> [batch, frames, num_ada_params, dim_per_param] + frame_reshaped = self.data.reshape(batch_size, self.num_frames, num_ada_params, -1)[:, :, indices, :] + table_values = scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to( + device=self.data.device, dtype=self.data.dtype + ) + frame_ada = (table_values + frame_reshaped).unbind(dim=2) + + # Expand each ada parameter spatially: [batch, frames, dim] -> [batch, frames, patches, dim] -> [batch, tokens, dim] + return tuple( + frame_val.unsqueeze(2).expand(batch_size, self.num_frames, self.patches_per_frame, -1) + .reshape(batch_size, -1, frame_val.shape[-1]) + for frame_val in frame_ada + ) + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = 
nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) + ) + + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + if isinstance(timestep, CompressedTimestep): + return timestep.expand_for_computation(scale_shift_table, batch_size, indices) + + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], + batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + return (*scale_shift_ada_values, *gate_ada_values) + + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + + def forward( + self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, + v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + # video + if run_vx: + # video self-attention + vshift_msa, vscale_msa = (self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 2))) + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + del vshift_msa, vscale_msa + attn1_out = self.attn1(norm_vx, pe=v_pe, 
mask=self_attention_mask, transformer_options=transformer_options) + del norm_vx + # video cross-attention + vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] + vx.addcmul_(attn1_out, vgate_msa) + del vgate_msa, attn1_out + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) + + # audio + if run_ax: + # audio self-attention + ashift_msa, ascale_msa = (self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 2))) + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + del ashift_msa, ascale_msa + attn1_out = self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + del norm_ax + # audio cross-attention + agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] + ax.addcmul_(attn1_out, agate_msa) + del agate_msa, attn1_out + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) + + # video - audio cross attention. + if run_a2v or run_v2a: + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + # audio to video cross attention + if run_a2v: + scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[:2] + scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[:2] + + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_a2v_v) + shift_ca_video_hidden_states_a2v_v + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + shift_ca_audio_hidden_states_a2v + del scale_ca_video_hidden_states_a2v_v, shift_ca_video_hidden_states_a2v_v, scale_ca_audio_hidden_states_a2v, shift_ca_audio_hidden_states_a2v + + a2v_out = self.audio_to_video_attn(vx_scaled, context=ax_scaled, pe=v_cross_pe, k_pe=a_cross_pe, transformer_options=transformer_options) + del vx_scaled, ax_scaled + + gate_out_a2v = self.get_ada_values(self.scale_shift_table_a2v_ca_video[4:, :], vx.shape[0], v_cross_gate_timestep)[0] + vx.addcmul_(a2v_out, gate_out_a2v) + del gate_out_a2v, a2v_out + + # video to audio cross attention + if run_v2a: + scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_audio[:4, :], ax.shape[0], a_cross_scale_shift_timestep)[2:4] + scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a = self.get_ada_values( + self.scale_shift_table_a2v_ca_video[:4, :], vx.shape[0], v_cross_scale_shift_timestep)[2:4] + + ax_scaled = ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + shift_ca_audio_hidden_states_v2a + vx_scaled = vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + shift_ca_video_hidden_states_v2a + del scale_ca_video_hidden_states_v2a, shift_ca_video_hidden_states_v2a, scale_ca_audio_hidden_states_v2a, shift_ca_audio_hidden_states_v2a + + v2a_out = self.video_to_audio_attn(ax_scaled, context=vx_scaled, pe=a_cross_pe, k_pe=v_cross_pe, transformer_options=transformer_options) + del ax_scaled, vx_scaled + + gate_out_v2a = 
self.get_ada_values(self.scale_shift_table_a2v_ca_audio[4:, :], ax.shape[0], a_cross_gate_timestep)[0] + ax.addcmul_(v2a_out, gate_out_v2a) + del gate_out_v2a, v2a_out + + del vx_norm3, ax_norm3 + + # video feedforward + if run_vx: + vshift_mlp, vscale_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, 5)) + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + del vshift_mlp, vscale_mlp + + ff_out = self.ff(vx_scaled) + del vx_scaled + + vgate_mlp = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(5, 6))[0] + vx.addcmul_(ff_out, vgate_mlp) + del vgate_mlp, ff_out + + # audio feedforward + if run_ax: + ashift_mlp, ascale_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, 5)) + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + del ashift_mlp, ascale_mlp + + ff_out = self.audio_ff(ax_scaled) + del ax_scaled + + agate_mlp = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(5, 6))[0] + ax.addcmul_(ff_out, agate_mlp) + del agate_mlp, ff_out + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" + + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific 
components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) + + self.audio_embeddings_connector = Embeddings1DConnector( + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, + double_precision_rope=True, + apply_gated_attention=connector_gated_attention, + dtype=dtype, + device=device, + operations=self.operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, + double_precision_rope=True, + apply_gated_attention=connector_gated_attention, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def preprocess_text_embeds(self, context, 
unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] + return torch.concat((out_vid, out_audio), dim=-1) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is 
None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + + has_spatial_mask = False + if denoise_mask is not None: + # check if any frame has spatial variation (inpainting) + for frame_idx in range(denoise_mask.shape[2]): + frame_mask = denoise_mask[0, 0, frame_idx] + if frame_mask.numel() > 0 and frame_mask.min() != frame_mask.max(): + has_spatial_mask = True + break + + [vx, v_pixel_coords, additional_args] = super()._process_input( + vx, keyframe_idxs, denoise_mask, **kwargs + ) + additional_args["has_spatial_mask"] = has_spatial_mask + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + + # Inject reference audio for ID-LoRA in-context conditioning + ref_audio = kwargs.get("ref_audio", None) + ref_audio_seq_len = 0 + if ref_audio is not None: + ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device) + if ref_tokens.shape[0] < ax.shape[0]: + ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1) + ref_audio_seq_len = ref_tokens.shape[1] + B = ax.shape[0] + + # Compute negative temporal positions matching ID-LoRA convention: + # offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0 + p = self.a_patchifier + tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate + ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device) + ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device) + time_offset = ref_end[-1].item() + tpl + ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1) + ref_pos = torch.stack([ref_start, ref_end], dim=-1) + + additional_args["ref_audio_seq_len"] = ref_audio_seq_len + additional_args["target_audio_seq_len"] = ax.shape[1] + ax = torch.cat([ref_tokens, ax], dim=1) + a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2) + + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. 
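+        # Illustrative summary (hypothetical names, not part of the patch logic):
+        # this method builds three groups of AdaLN conditioning tensors from sigma:
+        #   1. v_timestep / v_embedded_timestep: per-video-token embeddings from
+        #      self.adaln_single, wrapped in CompressedTimestep so that when all
+        #      tokens of a frame share one sigma only one row per frame is stored.
+        #   2. a_timestep / a_embedded_timestep: the audio-side equivalents from
+        #      self.audio_adaln_single.
+        #   3. cross_av_timestep_ss: scale/shift and gate embeddings for the AV
+        #      cross-attention layers, where each stream is conditioned on the
+        #      *other* stream's max sigma.
+        # The shared pattern, sketched with hypothetical shapes:
+        #   ts, emb = adaln(sigma.flatten() * scale_mult, {...}, batch_size=B, hidden_dtype=dt)
+        #   ts = ts.view(B, -1, ts.shape[-1])  # [B, num_tokens, num_ada_params * dim]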
+ grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep_scaled = timestep * self.timestep_scale_multiplier + + v_timestep, v_embedded_timestep = self.adaln_single( + timestep_scaled.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width] + # Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width + orig_shape = kwargs.get("orig_shape") + has_spatial_mask = kwargs.get("has_spatial_mask", None) + v_patches_per_frame = None + if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: + # orig_shape[3] = height, orig_shape[4] = width (in latent space) + v_patches_per_frame = orig_shape[3] * orig_shape[4] + + # Reshape to [batch_size, num_tokens, dim] and compress for storage + v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) + v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0 and a_timestep is not None: + # Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma. + target_len = kwargs.get("target_audio_seq_len") + if a_timestep.dim() <= 1: + a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len) + ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype) + a_timestep = torch.cat([ref_ts, a_timestep], dim=1) + if a_timestep is not None: + a_timestep_scaled = a_timestep * self.timestep_scale_multiplier + a_timestep_flat = a_timestep_scaled.flatten() + timestep_flat = timestep_scaled.flatten() + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + # Cross-attention timesteps - compress these too + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + timestep.max().expand_as(a_timestep_flat), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + a_timestep.max().expand_as(timestep_flat), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Compress cross-attention timesteps (only video side, audio is too small to benefit) + # v_patches_per_frame is None for spatial masks, set for temporal masks or no mask + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep.view(batch_size, -1, av_ca_audio_scale_shift_timestep.shape[-1]), + 
CompressedTimestep(av_ca_video_scale_shift_timestep.view(batch_size, -1, av_ca_video_scale_shift_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible + CompressedTimestep(av_ca_a2v_gate_noise_timestep.view(batch_size, -1, av_ca_a2v_gate_noise_timestep.shape[-1]), v_patches_per_frame), # video - compressed if possible + av_ca_v2a_gate_noise_timestep.view(batch_size, -1, av_ca_v2a_gate_noise_timestep.shape[-1]), + ] + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep_flat, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + # Audio timesteps + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) + else: + a_timestep = timestep_scaled + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + a_prompt_timestep = None + + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ + v_embedded_timestep, + a_embedded_timestep, + ], None + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + + v_context, a_context = torch.split( + context, [v_context_dim, a_context_dim], len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.caption_proj_before_connector is False: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, audio_dim) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. 
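+        # Illustrative note: both streams are embedded on a shared 1-D time axis so
+        # that audio tokens and video frames overlapping in wall-clock time receive
+        # matching rotary phases. Each coordinate carries a (start, end) pair and
+        # use_middle_indices_grid=True rotates at the token's temporal midpoint:
+        #   t_mid = (t_start + t_end) / 2   # seconds
+        # Video frame indices are first converted to seconds via 1.0 / frame_rate,
+        # and a single max_pos is shared so both streams use one normalization scale.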
+ max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + 
v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, + ) + + return [vx, ax] + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + + # Trim reference audio tokens before unpatchification + ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0) + if ref_audio_seq_len > 0: + ax = ax[:, ref_audio_seq_len:] + if a_embedded_timestep.shape[1] > 1: + a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:] + + # Expand compressed video timestep if needed + if isinstance(v_embedded_timestep, CompressedTimestep): + v_embedded_timestep = v_embedded_timestep.expand() + + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. 
+ + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000..2811080be --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,307 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. 
+ """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + apply_gated_attention=False, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.empty( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else 
source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6 + start = 1 + end = theta + + if spacing == "exp": + indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem)) + indices = indices.to(dtype=dtype, device=device) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace( + start, end, dim // n_elem, device=device, dtype=dtype + ) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // n_elem, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + return indices + + def precompute_freqs_cis(self, indices_grid, spacing="exp", out_dtype=None): + dim = self.inner_dim + n_elem = 2 # 2 because of cos and sin + freqs = self.precompute_freqs(indices_grid, spacing) + if self.split_rope: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis( + freqs, pad_size, self.num_attention_heads + ) + else: + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(dtype=out_dtype), sin_freq.to(dtype=out_dtype), self.split_rope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. 
Input + + if self.num_learnable_registers: + num_registers_duplications = math.ceil( + max(1024, hidden_states.shape[1]) / self.num_learnable_registers + ) + learnable_registers = torch.tile( + self.learnable_registers.to(hidden_states), (num_registers_duplications, 1) + ) + + hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) + + if attention_mask is not None: + attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device) + + indices_grid = torch.arange( + hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device + ) + indices_grid = indices_grid[None, None, :] + freqs_cis = self.precompute_freqs_cis(indices_grid, out_dtype=hidden_states.dtype) + + # 2. Blocks + for block_idx, block in enumerate(self.transformer_1d_blocks): + hidden_states = block( + hidden_states, attention_mask=attention_mask, pe=freqs_cis + ) + + # 3. Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000..78ed7653f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). 
+ """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. 
+ + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + 
num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,50 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import logging +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +logger = logging.getLogger(__name__) + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +76,7 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +108,9 @@ class TimestepEmbedding(nn.Module): post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +127,9 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +178,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = 
False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +212,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. """ - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +241,7 @@ class AdaLayerNormSingle(nn.Module): embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. 
@@ -192,18 +249,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -212,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -222,23 +309,69 @@ class GELU_approx(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... 
d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + apply_gated_attention=False, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +387,17 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,38 +407,79 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) 
+ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.cross_attention_adaln = cross_attention_adaln + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + 
)
+        else:
+            x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

         y = comfy.ldm.common_dit.rms_norm(x)
         y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@@ -305,117 +487,539 @@ class BasicTransformerBlock(nn.Module):
         return x

+def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
+    """Compute a single global prompt timestep for cross-attention ADaLN.
+
+    Uses the max across tokens (matching JAX max_per_segment) and broadcasts
+    over text tokens. Returns None when *adaln_module* is None.
+    """
+    if adaln_module is None:
+        return None
+    ts_input = (
+        timestep_scaled.max(dim=1, keepdim=True).values.flatten()
+        if timestep_scaled.dim() > 1
+        else timestep_scaled.flatten()
+    )
+    prompt_ts, _ = adaln_module(
+        ts_input,
+        {"resolution": None, "aspect_ratio": None},
+        batch_size=batch_size,
+        hidden_dtype=hidden_dtype,
+    )
+    return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
+
+
+def apply_cross_attention_adaln(
+    x, context, attn, q_shift, q_scale, q_gate,
+    prompt_scale_shift_table, prompt_timestep,
+    attention_mask=None, transformer_options={},
+):
+    """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
+
+    Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
+    that both regular tensors and CompressedTimestep are supported.
+    """
+    batch_size = x.shape[0]
+    shift_kv, scale_kv = (
+        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
+    ).unbind(dim=2)
+    attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
+    encoder_hidden_states = context * (1 + scale_kv) + shift_kv
+    return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
+
 def get_fractional_positions(indices_grid, max_pos):
+    n_pos_dims = indices_grid.shape[1]
+    assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
     fractional_positions = torch.stack(
-        [
-            indices_grid[:, i] / max_pos[i]
-            for i in range(3)
-        ],
+        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
         dim=-1,
     )
     return fractional_positions

-def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
-    dtype = torch.float32
-    device = indices_grid.device
+@functools.lru_cache(maxsize=5)
+def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _device=None):
+    # _device is accepted (and ignored) so both freq-grid generators share one call signature.
+    theta = positional_embedding_theta
+    start = 1
+    end = theta
+
+    n_elem = 2 * positional_embedding_max_pos_count
+    pow_indices = np.power(
+        theta,
+        np.linspace(
+            _log_base(start, theta),
+            _log_base(end, theta),
+            inner_dim // n_elem,
+            dtype=np.float64,
+        ),
+    )
+    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
+
+def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
+    theta = positional_embedding_theta
+    start = 1
+    end = theta
+    n_elem = 2 * positional_embedding_max_pos_count
+
+    indices = theta ** (
+        torch.linspace(
+            math.log(start, theta),
+            math.log(end, theta),
+            inner_dim // n_elem,
+            device=device,
+            dtype=torch.float32,
+        )
+    )
+    indices = indices.to(dtype=torch.float32)
+
+    indices = indices * math.pi / 2
+
+    return indices
+
+def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
+    if use_middle_indices_grid:
+        assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2
+        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
+        indices_grid = (indices_grid_start + indices_grid_end) / 2.0
+    elif len(indices_grid.shape) == 4:
+        indices_grid = indices_grid[..., 0]

     # Get fractional positions and compute frequency indices
     fractional_positions = get_fractional_positions(indices_grid, max_pos)
-    indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
+    indices = indices.to(device=fractional_positions.device)

-    # Compute frequencies and apply cos/sin
-    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
-    cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
-    sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
+    freqs = (
+        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
+        .transpose(-1, -2)
+        .flatten(2)
+    )
+    return freqs

-    # Pad if dim is not divisible by 6
-    if dim % 6 != 0:
-        padding_size = dim % 6
-        cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
-        sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
+def interleaved_freqs_cis(freqs, pad_size):
+    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
+    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
+    if pad_size != 0:
+        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
+        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
+        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
+        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
+    return cos_freq, sin_freq

-    # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
-    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
-    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
+def split_freqs_cis(freqs, pad_size, num_attention_heads):
+    cos_freq = freqs.cos()
+    sin_freq = freqs.sin()

-    # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
-    freqs_cis = torch.stack([
-        torch.stack([cos_vals, -sin_vals], dim=-1),
-        torch.stack([sin_vals, cos_vals], dim=-1)
-    ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
+    if pad_size != 0:
+        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
+        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])

-    return freqs_cis
+        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
+        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)

+    # Reshape freqs to be compatible with multi-head attention
+    B, T, half_HD = cos_freq.shape

-class LTXVModel(torch.nn.Module):
-    def __init__(self,
-                 in_channels=128,
-                 cross_attention_dim=2048,
-                 attention_head_dim=64,
-                 num_attention_heads=32,

+    cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
+    sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)

-                 caption_channels=4096,
-                 num_layers=28,
+    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B,H,T,D//2)
+    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B,H,T,D//2)
+    return cos_freq, sin_freq
+
+
+class LTXBaseModel(torch.nn.Module, ABC):
+    """
+    Abstract base class for LTX models (Lightricks Transformer models).
- positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. + """ + + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device + ) + + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, 
device=device, operations=self.operations ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. + + Base implementation returns None (no attenuation). Subclasses that + support guide-based attention control should override this. + """ + return None + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. 
Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep_scaled = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep_scaled.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_proj_before_connector is False: + context = self.caption_projection(context) + + context = context.view(batch_size, -1, x.shape[-1]) + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe + + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask + + def forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. 
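+
+        Thin wrapper: delegates to _forward through a WrapperExecutor so that
+        any DIFFUSION_MODEL wrappers registered in transformer_options run
+        around the model call.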
+ + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Build self-attention mask for per-guide attenuation + self_attention_mask = self._build_guide_self_attention_mask( + x, transformer_options, merged_args + ) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, + transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + **merged_args, + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x + + +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + 
num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, + dtype=dtype, + device=device, + operations=self.operations, ) - for d in range(num_layers) + for _ in range(self.num_layers) ] ) + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) - - self.patchifier = SymmetricPatchifier(1) - - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) - - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - - orig_shape = list(x.shape) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,50 +1027,267 @@ class LTXVModel(torch.nn.Module): causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] 
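+            # x and pixel_coords now hold only the grid_mask-selected tokens;
+            # _process_output scatters the result back into the full-length
+            # sequence (positions dropped here stay zero there).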
- fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + + # Compute per-guide surviving token counts from guide_attention_entries. + # Each entry tracks one guide reference; they are appended in order and + # their pre_filter_counts partition the kf_grid_mask. + guide_entries = kwargs.get("guide_attention_entries", None) + if guide_entries: + total_pfc = sum(e["pre_filter_count"] for e in guide_entries) + if total_pfc != len(kf_grid_mask): + raise ValueError( + f"guide pre_filter_counts ({total_pfc}) != " + f"keyframe grid mask length ({len(kf_grid_mask)})" + ) + resolved_entries = [] + offset = 0 + for entry in guide_entries: + pfc = entry["pre_filter_count"] + entry_mask = kf_grid_mask[offset:offset + pfc] + surviving = int(entry_mask.sum().item()) + resolved_entries.append({ + **entry, + "surviving_count": surviving, + }) + offset += pfc + additional_args["resolved_guide_entries"] = resolved_entries + + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs + + # Total surviving guide tokens (all guides) + additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] x = self.patchify_proj(x) - timestep = timestep * 1000.0 + return x, pixel_coords, additional_args - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max + def _build_guide_self_attention_mask(self, x, transformer_options, merged_args): + """Build self-attention mask for per-guide attention attenuation. - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) + Reads resolved_guide_entries from merged_args (computed in _process_input) + to build a log-space additive bias mask that attenuates noisy ↔ guide + attention for each guide reference independently. - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, + Returns None if no attenuation is needed (all strengths == 1.0 and no + spatial masks, or no guide tokens). + """ + if isinstance(x, list): + # AV model: x = [vx, ax]; use vx for token count and device + total_tokens = x[0].shape[1] + device = x[0].device + dtype = x[0].dtype + else: + total_tokens = x.shape[1] + device = x.device + dtype = x.dtype + + num_guide_tokens = merged_args.get("num_guide_tokens", 0) + if num_guide_tokens == 0: + return None + + resolved_entries = merged_args.get("resolved_guide_entries", None) + if not resolved_entries: + return None + + # Check if any attenuation is actually needed + needs_attenuation = any( + e["strength"] < 1.0 or e.get("pixel_mask") is not None + for e in resolved_entries ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] + if not needs_attenuation: + return None + + # Build per-guide-token weights for all tracked guide tokens. + # Guides are appended in order at the end of the sequence. 
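+        # e.g. with two guides whose surviving counts are [a, b], tokens
+        # [guide_start, guide_start + a) get guide 0's weights and
+        # [guide_start + a, guide_start + a + b) get guide 1's, concatenated
+        # in that order below.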
+ guide_start = total_tokens - num_guide_tokens + all_weights = [] + total_tracked = 0 + + for entry in resolved_entries: + surviving = entry["surviving_count"] + if surviving == 0: + continue + + strength = entry["strength"] + pixel_mask = entry.get("pixel_mask") + latent_shape = entry.get("latent_shape") + + if pixel_mask is not None and latent_shape is not None: + f_lat, h_lat, w_lat = latent_shape + per_token = self._downsample_mask_to_latent( + pixel_mask.to(device=device, dtype=dtype), + f_lat, h_lat, w_lat, + ) + # per_token shape: (B, f_lat*h_lat*w_lat). + # Collapse batch dim — the mask is assumed identical across the + # batch; validate and take the first element to get (1, tokens). + if per_token.shape[0] > 1: + ref = per_token[0] + for bi in range(1, per_token.shape[0]): + if not torch.equal(ref, per_token[bi]): + logger.warning( + "pixel_mask differs across batch elements; " + "using first element only." + ) + break + per_token = per_token[:1] + # `surviving` is the post-grid_mask token count. + # Clamp to surviving to handle any mismatch safely. + n_weights = min(per_token.shape[1], surviving) + weights = per_token[:, :n_weights] * strength # (1, n_weights) + else: + weights = torch.full( + (1, surviving), strength, device=device, dtype=dtype + ) + + all_weights.append(weights) + total_tracked += weights.shape[1] + + if not all_weights: + return None + + # Concatenate per-token weights for all tracked guides + tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) + + # Check if any weight is actually < 1.0 (otherwise no attenuation needed) + if (tracked_weights >= 1.0).all(): + return None + + # Build the mask: guide tokens are at the end of the sequence. + # Tracked guides come first (in order), untracked follow. + return self._build_self_attention_mask( + total_tokens, num_guide_tokens, total_tracked, + tracked_weights, guide_start, device, dtype, ) - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + @staticmethod + def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): + """Downsample a pixel-space mask to per-token latent weights. + Args: + mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1]. + f_lat: Number of latent frames (pre-dilation original count). + h_lat: Latent height (pre-dilation original height). + w_lat: Latent width (pre-dilation original width). + + Returns: + (B, F_lat * H_lat * W_lat) flattened per-token weights. + """ + b = mask.shape[0] + f_pix = mask.shape[2] + + # Spatial downsampling: area interpolation per frame + spatial_down = torch.nn.functional.interpolate( + rearrange(mask, "b 1 f h w -> (b f) 1 h w"), + size=(h_lat, w_lat), + mode="area", + ) + spatial_down = rearrange(spatial_down, "(b f) 1 h w -> b 1 f h w", b=b) + + # Temporal downsampling: first pixel frame maps to first latent frame, + # remaining pixel frames are averaged in groups for causal temporal structure. + first_frame = spatial_down[:, :, :1, :, :] + if f_pix > 1 and f_lat > 1: + remaining_pix = f_pix - 1 + remaining_lat = f_lat - 1 + t = remaining_pix // remaining_lat + if t < 1: + # Fewer pixel frames than latent frames — upsample by repeating + # the available pixel frames via nearest interpolation. 
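+                    # e.g. f_pix == 3, f_lat == 6: the 2 remaining pixel frames
+                    # are stretched to 5 latent frames by nearest-neighbour
+                    # repetition before first_frame is prepended below.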
+ rest_flat = rearrange( + spatial_down[:, :, 1:, :, :], + "b 1 f h w -> (b h w) 1 f", + ) + rest_up = torch.nn.functional.interpolate( + rest_flat, size=remaining_lat, mode="nearest", + ) + rest = rearrange( + rest_up, "(b h w) 1 f -> b 1 f h w", + b=b, h=h_lat, w=w_lat, + ) + else: + # Trim trailing pixel frames that don't fill a complete group + usable = remaining_lat * t + rest = rearrange( + spatial_down[:, :, 1:1 + usable, :, :], + "b 1 (f t) h w -> b 1 f t h w", + t=t, + ) + rest = rest.mean(dim=3) + latent_mask = torch.cat([first_frame, rest], dim=2) + elif f_lat > 1: + # Single pixel frame but multiple latent frames — repeat the + # single frame across all latent frames. + latent_mask = first_frame.expand(-1, -1, f_lat, -1, -1) + else: + latent_mask = first_frame + + return rearrange(latent_mask, "b 1 f h w -> b (f h w)") + + @staticmethod + def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count, + tracked_weights, guide_start, device, dtype): + """Build a log-space additive self-attention bias mask. + + Attenuates attention between noisy tokens and tracked guide tokens. + Untracked guide tokens (at the end of the guide portion) keep full attention. + + Args: + total_tokens: Total sequence length. + num_guide_tokens: Total guide tokens (all guides) at end of sequence. + tracked_count: Number of tracked guide tokens (first in the guide portion). + tracked_weights: (1, tracked_count) tensor, values in [0, 1]. + guide_start: Index where guide tokens begin in the sequence. + device: Target device. + dtype: Target dtype. + + Returns: + (1, 1, total_tokens, total_tokens) additive bias mask. + 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked. + """ + finfo = torch.finfo(dtype) + mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype) + tracked_end = guide_start + tracked_count + + # Convert weights to log-space bias + w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count) + log_w = torch.full_like(w, finfo.min) + positive_mask = w > 0 + if positive_mask.any(): + log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny)) + + # noisy → tracked guides: each noisy row gets the same per-guide weight + mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1) + # tracked guides → noisy: each guide row broadcasts its weight across noisy cols + mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1) + + return mask + + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": 
attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -476,18 +1297,32 @@ class LTXVModel(torch.nn.Module): timestep=timestep, pe=pe, transformer_options=transformer_options, + self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9f..8f9a41186 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. """ + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] 
+ 1 - scale_factors[0]).clamp(min=0)
     return pixel_coords


 class Patchifier(ABC):
-    def __init__(self, patch_size: int):
+    def __init__(self, patch_size: int, start_end: bool = False):
         super().__init__()
         self._patch_size = (1, patch_size, patch_size)
+        self.start_end = start_end

     @abstractmethod
     def patchify(
@@ -71,11 +74,23 @@ class Patchifier(ABC):
             torch.arange(0, latent_width, self._patch_size[2], device=device),
             indexing="ij",
         )
-        latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
-        latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
-        latent_coords = rearrange(
-            latent_coords, "b c f h w -> b c (f h w)", b=batch_size
+        latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0)
+        delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None]
+        latent_sample_coords_end = latent_sample_coords_start + delta
+
+        latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+        latent_sample_coords_start = rearrange(
+            latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size
         )
+        if self.start_end:
+            latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+            latent_sample_coords_end = rearrange(
+                latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size
+            )
+
+            latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1)
+        else:
+            latent_coords = latent_sample_coords_start
         return latent_coords
@@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier):
             q=self._patch_size[2],
         )
         return latents
+
+
+class AudioPatchifier(Patchifier):
+    def __init__(self, patch_size: int,
+                 sample_rate=16000,
+                 hop_length=160,
+                 audio_latent_downsample_factor=4,
+                 is_causal=True,
+                 start_end=False,
+                 shift=0
+                 ):
+        super().__init__(patch_size, start_end=start_end)
+        self.hop_length = hop_length
+        self.sample_rate = sample_rate
+        self.audio_latent_downsample_factor = audio_latent_downsample_factor
+        self.is_causal = is_causal
+        self.shift = shift
+
+    def copy_with_shift(self, shift):
+        # _patch_size is stored as (1, p, p); pass the scalar patch size through.
+        return AudioPatchifier(
+            self._patch_size[1], self.sample_rate, self.hop_length, self.audio_latent_downsample_factor,
+            self.is_causal, self.start_end, shift
+        )
+
+    def _get_audio_latent_time_in_sec(self, start_latent: int, end_latent: int, dtype: torch.dtype, device: torch.device = None):
+        audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device)
+        audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor
+        if self.is_causal:
+            audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0)
+        return audio_mel_frame * self.hop_length / self.sample_rate
+
+    def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        # audio_latents: (batch, channels, time, freq)
+        b, _, t, _ = audio_latents.shape
+        audio_latents = rearrange(
+            audio_latents,
+            "b c t f -> b t (c f)",
+        )
+
+        audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device)
+        audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
+
+        if self.start_end:
+            audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device)
+            audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1)
+
+            audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1)
+        else:
+            audio_latents_timings = audio_latents_start_timings
+        return audio_latents, audio_latents_timings
+
+    def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor:
+        # audio_latents: (batch, time, freq * channels)
+        audio_latents = rearrange(
+            audio_latents, "b t (c f) -> b c t f", c=channels, f=freq
+        )
+        return audio_latents
diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py
new file mode 100644
index 000000000..fa0a00748
--- /dev/null
+++ b/comfy/ldm/lightricks/vae/audio_vae.py
@@ -0,0 +1,282 @@
+import json
+from dataclasses import dataclass
+import math
+import torch
+import torchaudio
+
+import comfy.model_management
+import comfy.model_patcher
+import comfy.utils as utils
+from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
+from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
+from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
+    CausalityAxis,
+    CausalAudioAutoencoder,
+)
+from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
+
+LATENT_DOWNSAMPLE_FACTOR = 4
+
+
+@dataclass(frozen=True)
+class AudioVAEComponentConfig:
+    """Container for model component configuration extracted from metadata."""
+
+    autoencoder: dict
+    vocoder: dict
+
+    @classmethod
+    def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig":
+        assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE"
+
+        raw_config = metadata["config"]
+        if isinstance(raw_config, str):
+            parsed_config = json.loads(raw_config)
+        else:
+            parsed_config = raw_config
+
+        audio_config = parsed_config.get("audio_vae")
+        vocoder_config = parsed_config.get("vocoder")
+
+        assert audio_config is not None, "Audio VAE config is required for audio VAE"
+        assert vocoder_config is not None, "Vocoder config is required for audio VAE"
+
+        return cls(autoencoder=audio_config, vocoder=vocoder_config)
+
+
+class ModelDeviceManager:
+    """Manages device placement and GPU residency for the composed model."""
+
+    def __init__(self, module: torch.nn.Module):
+        load_device = comfy.model_management.get_torch_device()
+        offload_device = comfy.model_management.vae_offload_device()
+        self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
+
+    def ensure_model_loaded(self) -> None:
+        comfy.model_management.free_memory(
+            self.patcher.model_size(),
+            self.patcher.load_device,
+        )
+        comfy.model_management.load_model_gpu(self.patcher)
+
+    def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
+        return tensor.to(self.patcher.load_device)
+
+    @property
+    def load_device(self):
+        return self.patcher.load_device
+
+
+class AudioLatentNormalizer:
+    """Applies per-channel statistics in patch space and restores original layout."""
+
+    def __init__(self, patchifier: AudioPatchifier, statistics_processor: torch.nn.Module):
+        self.patchifier = patchifier
+        self.statistics = statistics_processor
+
+    def normalize(self, latents: torch.Tensor) -> torch.Tensor:
+        channels = latents.shape[1]
+        freq = latents.shape[3]
+        patched, _ = self.patchifier.patchify(latents)
+        normalized = self.statistics.normalize(patched)
+        return self.patchifier.unpatchify(normalized, channels=channels, freq=freq)
+
+    def denormalize(self, latents: torch.Tensor) -> torch.Tensor:
+        channels = latents.shape[1]
+        freq = latents.shape[3]
+        patched, _ = self.patchifier.patchify(latents)
+
denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. 
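+        # ensure_model_loaded() frees enough memory for this model's footprint
+        # and then loads the patched module onto the load device.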
+ self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + if waveform.shape[1] == 1: + waveform = waveform.expand(-1, expected_channels, *waveform.shape[2:]) + else: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py 
b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000..b556b128f --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,900 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. 
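+
+    The causal axis is configurable: with causality_axis=HEIGHT (the default),
+    all of the height padding is applied above the input, so output row t only
+    depends on input rows <= t; WIDTH instead applies all of the width padding
+    to the left of the input.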
+ """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. 
+ # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. 
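+            # Plain 2x2 average pooling; halves both spatial dimensions.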
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type 
'{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. 
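+
+        As a rough shape example (default config: ch_mult=[1, 2, 4], double_z=True,
+        z_channels=8), a [batch, 2, time, 64] mel input maps to approximately
+        [batch, 16, time / 4, 16].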
+ + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, 
norm_type=self.norm_type)
+        else:
+            self.mid.attn_1 = nn.Identity()
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+            norm_type=self.norm_type,
+            causality_axis=causality_axis,
+        )
+
+        self.ch_mult = ch_mult  # stored so get_config() can report it
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for _ in range(self.num_res_blocks + 1):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                        norm_type=self.norm_type,
+                        causality_axis=causality_axis,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in, normtype=self.norm_type)
+        self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis)
+
+    def _adjust_output_shape(self, decoded_output, target_shape):
+        """
+        Adjust output shape to match target dimensions for variable-length audio.
+
+        This function handles the common case where decoded audio spectrograms need to be
+        resized to match a specific target shape.
+
+        Args:
+            decoded_output: Tensor of shape (batch, channels, time, frequency)
+            target_shape: Target shape tuple (batch, channels, time, frequency)
+
+        Returns:
+            Tensor adjusted to match target_shape exactly
+        """
+        # Current output shape: (batch, channels, time, frequency)
+        _, _, current_time, current_freq = decoded_output.shape
+        _, target_channels, target_time, target_freq = target_shape
+
+        # Step 1: Crop first to avoid exceeding target dimensions
+        decoded_output = decoded_output[
+            :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
+        ]
+
+        # Step 2: Calculate padding needed for time and frequency dimensions
+        time_padding_needed = target_time - decoded_output.shape[2]
+        freq_padding_needed = target_freq - decoded_output.shape[3]
+
+        # Step 3: Apply padding if needed
+        if time_padding_needed > 0 or freq_padding_needed > 0:
+            # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom)
+            # For audio: pad_left/right = frequency, pad_top/bottom = time
+            padding = (
+                0,
+                max(freq_padding_needed, 0),  # frequency padding (left, right)
+                0,
+                max(time_padding_needed, 0),  # time padding (top, bottom)
+            )
+            decoded_output = F.pad(decoded_output, padding)
+
+        # Step 4: Final safety crop to ensure exact target shape
+        decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
+
+        return decoded_output
+
+    def get_config(self):
+        return {
+            "ch": self.ch,
+            "out_ch": self.out_ch,
+            "ch_mult": self.ch_mult,
+            "num_res_blocks": self.num_res_blocks,
+            "in_channels": self.in_channels,
+            "resolution": self.resolution,
+            "z_channels": self.z_channels,
+        }
+
+    def forward(self, latent_features, target_shape=None):
+        """
+        Decode latent features back to audio spectrograms.
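+
+        Roughly the inverse of the Encoder: for the default config, a
+        [batch, 8, time / 4, 16] latent is upsampled back toward [batch, 2, time, 64]
+        and then cropped/padded to match target_shape exactly.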
+
+        Args:
+            latent_features: Encoded latent representation of shape (batch, channels, height, width)
+            target_shape: Target output shape (batch, channels, time, frequency).
+                Required; the decoded output is cropped/padded to match this shape
+
+        Returns:
+            Reconstructed audio spectrogram of shape (batch, channels, time, frequency)
+        """
+        assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder"
+
+        # Transform latent features to decoder's internal feature dimension
+        hidden_features = self.conv_in(latent_features)
+
+        # Middle processing
+        hidden_features = self.mid.block_1(hidden_features, temb=None)
+        hidden_features = self.mid.attn_1(hidden_features)
+        hidden_features = self.mid.block_2(hidden_features, temb=None)
+
+        # Upsampling
+        # Progressively increase spatial resolution from lowest to highest
+        for resolution_level in reversed(range(self.num_resolutions)):
+            # Apply residual blocks at current resolution level
+            for block_index in range(self.num_res_blocks + 1):
+                hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None)
+
+                if len(self.up[resolution_level].attn) > 0:
+                    hidden_features = self.up[resolution_level].attn[block_index](hidden_features)
+
+            if resolution_level != 0:
+                hidden_features = self.up[resolution_level].upsample(hidden_features)
+
+        # Output
+        if self.give_pre_end:
+            # Return intermediate features before final processing (for debugging/analysis)
+            decoded_output = hidden_features
+        else:
+            # Standard output path: normalize, activate, and convert to output channels
+            # Final normalization layer
+            hidden_features = self.norm_out(hidden_features)
+
+            # Apply SiLU (Swish) activation function
+            hidden_features = self.non_linearity(hidden_features)
+
+            # Final convolution to map to output channels (typically 2 for stereo audio)
+            decoded_output = self.conv_out(hidden_features)
+
+            # Optional tanh activation to bound output values to [-1, 1] range
+            if self.tanh_out:
+                decoded_output = torch.tanh(decoded_output)
+
+        # Adjust shape for audio data
+        if target_shape is not None:
+            decoded_output = self._adjust_output_shape(decoded_output, target_shape)
+
+        return decoded_output
+
+
+class processor(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.register_buffer("std-of-means", torch.empty(128))
+        self.register_buffer("mean-of-means", torch.empty(128))
+
+    def un_normalize(self, x):
+        return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x)
+
+    def normalize(self, x):
+        return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x)
+
+
+class CausalAudioAutoencoder(nn.Module):
+    def __init__(self, config=None):
+        super().__init__()
+
+        if config is None:
+            config = self.get_default_config()
+
+        model_config = config.get("model", {}).get("params", {})
+
+        self.sampling_rate = model_config.get(
+            "sampling_rate", config.get("sampling_rate", 16000)
+        )
+        encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
+        decoder_config = model_config.get("decoder", encoder_config)
+
+        # Load mel spectrogram parameters
+        self.mel_bins = encoder_config.get("mel_bins", 64)
+        self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
+        self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
+
+        # Store causality configuration at VAE level (not just in encoder internals)
+        causality_axis_value = encoder_config.get("causality_axis",
CausalityAxis.HEIGHT.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def get_default_config(self): + ddconfig = { + "double_z": True, + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, + "norm_type": "pixel", + "causality_axis": "height", + } + + config = { + "model": { + "params": { + "ddconfig": ddconfig, + "sampling_rate": 16000, + } + }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index 70d612e86..7515f0d4e 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -1,11 +1,11 @@ from typing import Tuple, Union +import threading import torch import torch.nn as nn import comfy.ops ops = comfy.ops.disable_weight_init - class CausalConv3d(nn.Module): def __init__( self, @@ -23,6 +23,11 @@ class CausalConv3d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels + if isinstance(stride, int): + self.time_stride = stride + else: + self.time_stride = stride[0] + kernel_size = (kernel_size, kernel_size, kernel_size) self.time_kernel_size = kernel_size[0] @@ -42,23 +47,43 @@ class CausalConv3d(nn.Module): padding_mode=spatial_padding_mode, groups=groups, ) + self.temporal_cache_state={} def forward(self, x, causal: bool = True): - if causal: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, self.time_kernel_size - 1, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x), dim=2) - else: - first_frame_pad = x[:, :, :1, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - last_frame_pad = x[:, :, -1:, :, :].repeat( - (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) - ) - x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) - x = self.conv(x) - return x + tid = threading.get_ident() + + cached, is_end = self.temporal_cache_state.get(tid, (None, False)) + if cached is None: + padding_length = self.time_kernel_size - 1 + if not causal: + padding_length = padding_length // 2 + if x.shape[2] == 0: + return x + cached = x[:, :, :1, :, :].repeat((1, 1, padding_length, 1, 1)) + pieces = [ cached, x ] + if is_end and not causal: + pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + input_length = sum([piece.shape[2] for piece in pieces]) + cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride) + + needs_caching = not is_end + if needs_caching and cache_length == 0: + self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False) + needs_caching = False + if needs_caching and x.shape[2] >= cache_length: + needs_caching = 
False + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) + + x = torch.cat(pieces, dim=2) + del pieces + del cached + + if needs_caching: + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) + elif is_end: + self.temporal_cache_state[tid] = (None, True) + + return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] @property def weight(self): diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 75ed069ad..998122c85 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -1,4 +1,5 @@ from __future__ import annotations +import threading import torch from torch import nn from functools import partial @@ -6,12 +7,39 @@ import math from einops import rearrange from typing import List, Optional, Tuple, Union from .conv_nd_factory import make_conv_nd, make_linear_nd +from .causal_conv3d import CausalConv3d from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops +import comfy.model_management +from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + +def mark_conv3d_ended(module): + tid = threading.get_ident() + for _, m in module.named_modules(): + if isinstance(m, CausalConv3d): + current = m.temporal_cache_state.get(tid, (None, False)) + m.temporal_cache_state[tid] = (current[0], True) + +def split2(tensor, split_point, dim=2): + return torch.split(tensor, [split_point, tensor.shape[dim] - split_point], dim=dim) + +def add_exchange_cache(dest, cache_in, new_input, dim=2): + if dest is not None: + if cache_in is not None: + cache_to_dest = min(dest.shape[dim], cache_in.shape[dim]) + lead_in_dest, dest = split2(dest, cache_to_dest, dim=dim) + lead_in_source, cache_in = split2(cache_in, cache_to_dest, dim=dim) + lead_in_dest.add_(lead_in_source) + body, new_input = split2(new_input, dest.shape[dim], dim) + dest.add_(body) + return torch_cat_if_needed([cache_in, new_input], dim=dim) + class Encoder(nn.Module): r""" The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. @@ -205,10 +233,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]: sample = self.conv_in(sample) checkpoint_fn = ( @@ -219,10 +244,14 @@ class Encoder(nn.Module): for down_block in self.down_blocks: sample = checkpoint_fn(down_block)(sample) + if sample is None or sample.shape[2] == 0: + return None sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + if sample is None or sample.shape[2] == 0: + return None if self.latent_log_var == "uniform": last_channel = sample[:, -1:, ...] 
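The temporal cache bookkeeping above is easier to see in one dimension. A minimal sketch of the idea (illustrative only, plain PyTorch, not part of the diff): carrying the last `kernel_size - stride` padded frames of each chunk forward makes chunk-by-chunk convolution match the one-shot causal convolution.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
k, s = 3, 1                        # temporal kernel size and stride
w = torch.randn(1, 1, k)           # one 1D conv kernel stands in for CausalConv3d
x = torch.randn(1, 1, 16)          # the full clip along the time axis

# One-shot causal conv: replicate the first frame k - 1 times, then convolve.
padded = torch.cat([x[..., :1].repeat(1, 1, k - 1), x], dim=2)
full = F.conv1d(padded, w, stride=s)

# Streaming: keep the last (k - s) frames of each buffer as the cache; for
# stride 1 this is the (k - 1)-frame tail, matching cache_length above.
cache = x[..., :1].repeat(1, 1, k - 1)   # the initial causal padding
outs = []
for chunk in torch.split(x, 4, dim=2):
    buf = torch.cat([cache, chunk], dim=2)
    outs.append(F.conv1d(buf, w, stride=s))
    cache = buf[..., -(k - s):]

assert torch.allclose(full, torch.cat(outs, dim=2), atol=1e-6)
```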
@@ -254,6 +283,64 @@
         return sample
 
+    def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
+        r"""The forward method of the `Encoder` class."""
+
+        max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2  # encoder is more memory-efficient than decoder
+        frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
+        frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
+
+        outputs = []
+        samples = [sample[:, :, :1, :, :]]
+        if sample.shape[2] > 1:
+            chunk_t = max(2, max_chunk_size // frame_size)
+            if chunk_t < 4:
+                chunk_t = 2
+            elif chunk_t < 8:
+                chunk_t = 4
+            else:
+                chunk_t = (chunk_t // 8) * 8
+            samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
+        for chunk_idx, chunk in enumerate(samples):
+            if chunk_idx == len(samples) - 1:
+                mark_conv3d_ended(self)
+            chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
+            output = self._forward_chunk(chunk)
+            if output is not None:
+                outputs.append(output)
+
+        return torch_cat_if_needed(outputs, dim=2)
+
+    def forward(self, *args, **kwargs):
+        try:
+            return self.forward_orig(*args, **kwargs)
+        finally:
+            # ComfyUI doesn't thread this kind of stuff today, but just in case
+            # we key on the thread to make it thread safe.
+            tid = threading.get_ident()
+            for _, module in self.named_modules():
+                if hasattr(module, "temporal_cache_state"):
+                    module.temporal_cache_state.pop(tid, None)
+
+
+MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
+MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
+MIN_CHUNK_SIZE = 32 * 1024 ** 2
+MAX_CHUNK_SIZE = 128 * 1024 ** 2
+
+def get_max_chunk_size(device: torch.device) -> int:
+    total_memory = comfy.model_management.get_total_memory(dev=device)
+
+    if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
+        return MIN_CHUNK_SIZE
+    if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
+        return MAX_CHUNK_SIZE
+
+    interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
+        MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
+    )
+    return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
 
 class Decoder(nn.Module):
     r"""
@@ -310,6 +397,10 @@
             output_channel = output_channel * block_params.get("multiplier", 2)
         if block_name == "compress_all":
             output_channel = output_channel * block_params.get("multiplier", 1)
+        if block_name == "compress_space":
+            output_channel = output_channel * block_params.get("multiplier", 1)
+        if block_name == "compress_time":
+            output_channel = output_channel * block_params.get("multiplier", 1)
 
         self.conv_in = make_conv_nd(
             dims,
@@ -341,18 +432,6 @@
                     timestep_conditioning=timestep_conditioning,
                     spatial_padding_mode=spatial_padding_mode,
                 )
-            elif block_name == "attn_res_x":
-                block = UNetMidBlock3D(
-                    dims=dims,
-                    in_channels=input_channel,
-                    num_layers=block_params["num_layers"],
-                    resnet_groups=norm_num_groups,
-                    norm_layer=norm_layer,
-                    inject_noise=block_params.get("inject_noise", False),
-                    timestep_conditioning=timestep_conditioning,
-                    attention_head_dim=block_params["attention_head_dim"],
-                    spatial_padding_mode=spatial_padding_mode,
-                )
             elif block_name == "res_x_y":
                 output_channel = output_channel // block_params.get("multiplier", 2)
                 block = ResnetBlock3D(
@@ -367,17 +446,21 @@
                     spatial_padding_mode=spatial_padding_mode,
                 )
             elif block_name == "compress_time":
+                output_channel = output_channel // block_params.get("multiplier", 1)
                 block =
DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -417,6 +500,17 @@ class Decoder(nn.Module): self.gradient_checkpointing = False + # Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset) + ts, hs, ws, to = 1, 1, 1, 0 + for block in self.up_blocks: + if isinstance(block, DepthToSpaceUpsample): + ts *= block.stride[0] + hs *= block.stride[1] + ws *= block.stride[2] + if block.stride[0] > 1: + to = to * block.stride[0] + 1 + self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to) + self.timestep_conditioning = timestep_conditioning if timestep_conditioning: @@ -427,16 +521,78 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) - # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: - def forward( + + def decode_output_shape(self, input_shape): + c, (ts, hs, ws), to = self._output_scale + return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + + def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size): + sample = sample_ref[0] + sample_ref[0] = None + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t + return + + up_block = self.up_blocks[idx] + if ended: + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size + + if num_chunks == 1: + # when we are not chunking, detach our x so the callee can free it as soon as they are done + next_sample_ref = [sample] + del sample + self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + return + else: + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, 
timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+
+    def forward_orig(
         self,
         sample: torch.FloatTensor,
         timestep: Optional[torch.Tensor] = None,
+        output_buffer: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         r"""The forward method of the `Decoder` class."""
         batch_size = sample.shape[0]
+        mark_conv3d_ended(self.conv_in)
         sample = self.conv_in(sample, causal=self.causal)
 
         checkpoint_fn = (
@@ -445,24 +601,13 @@
         )
 
+        timestep_shift_scale = None
         scaled_timestep = None
         if self.timestep_conditioning:
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
             scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
-
-        for up_block in self.up_blocks:
-            if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
-                sample = checkpoint_fn(up_block)(
-                    sample, causal=self.causal, timestep=scaled_timestep
-                )
-            else:
-                sample = checkpoint_fn(up_block)(sample, causal=self.causal)
-
-        sample = self.conv_norm_out(sample)
-
-        if self.timestep_conditioning:
             embedded_timestep = self.last_time_embedder(
                 timestep=scaled_timestep.flatten(),
                 resolution=None,
@@ -483,15 +628,31 @@
                 embedded_timestep.shape[-2],
                 embedded_timestep.shape[-1],
             )
-            shift, scale = ada_values.unbind(dim=1)
-            sample = sample * (1 + scale) + shift
+            timestep_shift_scale = ada_values.unbind(dim=1)
 
-        sample = self.conv_act(sample)
-        sample = self.conv_out(sample, causal=self.causal)
+        if output_buffer is None:
+            output_buffer = torch.empty(
+                self.decode_output_shape(sample.shape),
+                dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
+            )
+        output_offset = [0]
 
-        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
+        max_chunk_size = get_max_chunk_size(sample.device)
 
-        return sample
+        self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
+
+        return output_buffer
+
+    def forward(self, *args, **kwargs):
+        try:
+            return self.forward_orig(*args, **kwargs)
+        finally:
+            for _, module in self.named_modules():
+                # ComfyUI doesn't thread this kind of stuff today, but just in case
+                # we key on the thread to make it thread safe.
+ tid = threading.get_ident() + if hasattr(module, "temporal_cache_state"): + module.temporal_cache_state.pop(tid, None) class UNetMidBlock3D(nn.Module): @@ -604,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module): causal=True, spatial_padding_mode=spatial_padding_mode, ) + self.temporal_cache_state = {} def forward(self, x, causal: bool = True): - if self.stride[0] == 2: + tid = threading.get_ident() + cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None)) + if cached_input is not None: + x = torch_cat_if_needed([cached_input, x], dim=2) + cached_input = None + + if self.stride[0] == 2 and pad_first: x = torch.cat( [x[:, :, :1, :, :], x], dim=2 ) # duplicate first frames for padding + pad_first = False + + if x.shape[2] < self.stride[0]: + cached_input = x + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) + return None # skip connection x_in = rearrange( @@ -624,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module): # conv x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) + if self.stride[0] == 2 and x.shape[2] == 1: + if cached_x is not None: + x = torch_cat_if_needed([cached_x, x], dim=2) + cached_x = None + else: + cached_x = x + x = None - x = x + x_in + if x is not None: + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + cached = add_exchange_cache(x, cached, x_in, dim=2) + + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) return x @@ -663,8 +848,22 @@ class DepthToSpaceUpsample(nn.Module): ) self.residual = residual self.out_channels_reduction_factor = out_channels_reduction_factor + self.temporal_cache_state = {} def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None): + tid = threading.get_ident() + cached, drop_first_conv, drop_first_res = self.temporal_cache_state.get(tid, (None, True, True)) + y = self.conv(x, causal=causal) + y = rearrange( + y, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2 and y.shape[2] > 0 and drop_first_conv: + y = y[:, :, 1:, :, :] + drop_first_conv = False if self.residual: # Reshape and duplicate the input to match the output shape x_in = rearrange( @@ -676,21 +875,20 @@ class DepthToSpaceUpsample(nn.Module): ) num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor x_in = x_in.repeat(1, num_repeat, 1, 1, 1) - if self.stride[0] == 2: + if self.stride[0] == 2 and x_in.shape[2] > 0 and drop_first_res: x_in = x_in[:, :, 1:, :, :] - x = self.conv(x, causal=causal) - x = rearrange( - x, - "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) - if self.stride[0] == 2: - x = x[:, :, 1:, :, :] - if self.residual: - x = x + x_in - return x + drop_first_res = False + + if y.shape[2] == 0: + y = None + + cached = add_exchange_cache(y, cached, x_in, dim=2) + self.temporal_cache_state[tid] = (cached, drop_first_conv, drop_first_res) + + else: + self.temporal_cache_state[tid] = (None, drop_first_conv, False) + + return y class LayerNorm(nn.Module): def __init__(self, dim, eps, elementwise_affine=True) -> None: @@ -806,6 +1004,17 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / 
in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) + + self.temporal_cache_state={} def _feed_spatial_noise( self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor @@ -880,9 +1089,12 @@ class ResnetBlock3D(nn.Module): input_tensor = self.conv_shortcut(input_tensor) - output_tensor = input_tensor + hidden_states + tid = threading.get_ident() + cached = self.temporal_cache_state.get(tid, None) + cached = add_exchange_cache(hidden_states, cached, input_tensor, dim=2) + self.temporal_cache_state[tid] = cached - return output_tensor + return hidden_states def patchify(x, patch_size_hw, patch_size_t=1): @@ -930,9 +1142,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -941,13 +1150,18 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): + comfy_has_chunked_io = True + def __init__(self, version=0, config=None): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -962,6 +1176,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -969,6 +1184,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -978,7 +1194,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1078,15 +1294,15 @@ class VideoVAE(nn.Module): } return config - def encode(self, x): - frames_count = x.shape[2] - if ((frames_count - 1) % 8) != 0: - raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). 
Please check your input.") - means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + def encode(self, x, device=None): + x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :] + means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): - if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) + def decode_output_shape(self, input_shape): + return self.decoder.decode_output_shape(input_shape) + def decode(self, x, output_buffer=None): + if self.timestep_conditioning: #TODO: seed + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000..2481d8bdd --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,713 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import comfy.model_management +import numpy as np +import math + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, 
comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 
/ (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + 
ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). 
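+        # For example (illustrative numbers): the default upsample_rates
+        # [5, 4, 2, 2, 2] give upsample_factor = 5*4*2*2*2 = 160, so a 16 kHz
+        # autoencoder with hop_length 160 derives 16000 * 160 / 160 = 16000 Hz.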
+ self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) + + out_channels = 2 if stereo else 1 + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = self.act_post(x) + x = self.conv_post(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) + + return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. 
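(Concretely, forward_basis has shape (2 * n_freqs, 1, filter_length), and forward() reads
the first n_freqs output channels as the real part and the remaining n_freqs as the imaginary part.)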
Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. 
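+
+    Data flow: mel → Vocoder → low-rate audio → MelSTFT → BWE generator → residual,
+    which is summed with a Hann-resampled skip path and clamped to [-1, 1].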
+ + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 5628e2ba3..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -13,10 +13,54 @@ from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension +import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder -def modulate(x, scale): - return x * (1 + scale.unsqueeze(1)) +def invert_slices(slices, length): + sorted_slices = sorted(slices) + result = [] + current = 0 + + for start, end in sorted_slices: + if current < start: + result.append((current, start)) + current = max(current, end) + + if current < length: + result.append((current, length)) + + return result + + +def modulate(x, scale, timestep_zero_index=None): + if timestep_zero_index is None: + return x * (1 + scale.unsqueeze(1)) + else: + scale = (1 + scale.unsqueeze(1)) + actual_batch = scale.size(0) // 2 + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= scale[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= scale[:actual_batch] + return x + + +def apply_gate(gate, x, timestep_zero_index=None): + if timestep_zero_index is None: + return gate * x + else: + actual_batch = gate.size(0) // 2 + + slices = timestep_zero_index + invert = invert_slices(timestep_zero_index, x.shape[1]) + for s in slices: + x[:, s[0]:s[1]] *= gate[actual_batch:] + for s in invert: + x[:, s[0]:s[1]] *= gate[:actual_batch] + return x 
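+
+# Illustrative example (hypothetical values, not executed anywhere): for a
+# sequence of length 10 with timestep-zero spans [(2, 5), (7, 9)],
+# invert_slices returns the complement [(0, 2), (5, 7), (9, 10)]:
+#
+#   invert_slices([(2, 5), (7, 9)], 10)  ->  [(0, 2), (5, 7), (9, 10)]
+#
+# modulate() and apply_gate() use the two span lists to scale the listed spans
+# with one half of the doubled modulation batch and the complement spans with
+# the other half, so the two halves of the doubled timestep batch modulate
+# disjoint token spans.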
############################################################################# # Core NextDiT Model # @@ -258,6 +302,7 @@ class JointTransformerBlock(nn.Module): x_mask: torch.Tensor, freqs_cis: torch.Tensor, adaln_input: Optional[torch.Tensor]=None, + timestep_zero_index=None, transformer_options={}, ): """ @@ -276,18 +321,18 @@ class JointTransformerBlock(nn.Module): assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( + x = x + apply_gate(gate_msa.unsqueeze(1).tanh(), self.attention_norm2( clamp_fp16(self.attention( - modulate(self.attention_norm1(x), scale_msa), + modulate(self.attention_norm1(x), scale_msa, timestep_zero_index=timestep_zero_index), x_mask, freqs_cis, transformer_options=transformer_options, - )) + ))), timestep_zero_index=timestep_zero_index ) - x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( + x = x + apply_gate(gate_mlp.unsqueeze(1).tanh(), self.ffn_norm2( clamp_fp16(self.feed_forward( - modulate(self.ffn_norm1(x), scale_mlp), - )) + modulate(self.ffn_norm1(x), scale_mlp, timestep_zero_index=timestep_zero_index), + ))), timestep_zero_index=timestep_zero_index ) else: assert adaln_input is None @@ -345,13 +390,37 @@ class FinalLayer(nn.Module): ), ) - def forward(self, x, c): + def forward(self, x, c, timestep_zero_index=None): scale = self.adaLN_modulation(c) - x = modulate(self.norm_final(x), scale) + x = modulate(self.norm_final(x), scale, timestep_zero_index=timestep_zero_index) x = self.linear(x) return x +def pad_zimage(feats, pad_token, pad_tokens_multiple): + pad_extra = (-feats.shape[1]) % pad_tokens_multiple + return torch.cat((feats, pad_token.to(device=feats.device, dtype=feats.dtype, copy=True).unsqueeze(0).repeat(feats.shape[0], pad_extra, 1)), dim=1), pad_extra + + +def pos_ids_x(start_t, H_tokens, W_tokens, batch_size, device, transformer_options={}): + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) + + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) + x_pos_ids = torch.zeros((batch_size, H_tokens * W_tokens, 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = start_t + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() + return x_pos_ids + + class NextDiT(nn.Module): """ Diffusion model with a Transformer backbone. 
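
# A quick sketch of the two helpers above (illustrative numbers, not from the
# model code): pad_zimage pads the token axis up to the next multiple of
# pad_tokens_multiple, so 10 tokens with a multiple of 8 gain (-10) % 8 == 6
# pad tokens and come out with length 16. pos_ids_x fills (t, y, x) RoPE ids
# for an H_tokens x W_tokens grid, with rope_options optionally rescaling and
# shifting the two spatial axes.
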
@@ -378,10 +447,12 @@ class NextDiT(nn.Module): time_scale=1.0, pad_tokens_multiple=None, clip_text_dim=None, + siglip_feat_dim=None, image_model=None, device=None, dtype=None, operations=None, + **kwargs, ) -> None: super().__init__() self.dtype = dtype @@ -491,7 +562,43 @@ class NextDiT(nn.Module): for layer_id in range(n_layers) ] ) - self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + if siglip_feat_dim is not None: + self.siglip_embedder = nn.Sequential( + operation_settings.get("operations").RMSNorm(siglip_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + siglip_feat_dim, + dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.siglip_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=False, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + self.siglip_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + else: + self.siglip_embedder = None + self.siglip_refiner = None + self.siglip_pad_token = None + + # This norm final is in the lumina 2.0 code but isn't actually used for anything. + # self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) if self.pad_tokens_multiple is not None: @@ -530,70 +637,168 @@ class NextDiT(nn.Module): imgs = torch.stack(imgs, dim=0) return imgs - def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} - ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - orig_x = x - - if self.pad_tokens_multiple is not None: - pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple - cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) + def embed_cap(self, cap_feats=None, offset=0, bsz=1, device=None, dtype=None): + if cap_feats is not None: + cap_feats = self.cap_embedder(cap_feats) + cap_feats_len = cap_feats.shape[1] + if self.pad_tokens_multiple is not None: + cap_feats, _ = pad_zimage(cap_feats, self.cap_pad_token, self.pad_tokens_multiple) + else: + cap_feats_len = 0 + cap_feats = self.cap_pad_token.to(device=device, dtype=dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) - cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 + offset + embeds = (cap_feats,) + freqs_cis = (self.rope_embedder(cap_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + + def embed_all(self, x, 
cap_feats=None, siglip_feats=None, offset=0, omni=False, transformer_options={}): + bsz = 1 + pH = pW = self.patch_size + device = x.device + embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype) + + if (not omni) or self.siglip_embedder is None: + cap_feats_len = embeds[0].shape[1] + offset + embeds += (None,) + freqs_cis += (None,) + else: + cap_feats_len += offset + if siglip_feats is not None: + b, h, w, c = siglip_feats.shape + siglip_feats = siglip_feats.permute(0, 3, 1, 2).reshape(b, h * w, c) + siglip_feats = self.siglip_embedder(siglip_feats) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + siglip_pos_ids[:, :, 0] = cap_feats_len + 2 + siglip_pos_ids[:, :, 1] = (torch.linspace(0, h * 8 - 1, steps=h, dtype=torch.float32, device=device).floor()).view(-1, 1).repeat(1, w).flatten() + siglip_pos_ids[:, :, 2] = (torch.linspace(0, w * 8 - 1, steps=w, dtype=torch.float32, device=device).floor()).view(1, -1).repeat(h, 1).flatten() + if self.siglip_pad_token is not None: + siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple) # TODO: double check + siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra)) + else: + if self.siglip_pad_token is not None: + siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1) + siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device) + + if siglip_feats is None: + embeds += (None,) + freqs_cis += (None,) + else: + embeds += (siglip_feats,) + freqs_cis += (self.rope_embedder(siglip_pos_ids).movedim(1, 2),) B, C, H, W = x.shape x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - - rope_options = transformer_options.get("rope_options", None) - h_scale = 1.0 - w_scale = 1.0 - h_start = 0 - w_start = 0 - if rope_options is not None: - h_scale = rope_options.get("scale_y", 1.0) - w_scale = rope_options.get("scale_x", 1.0) - - h_start = rope_options.get("shift_y", 0.0) - w_start = rope_options.get("shift_x", 0.0) - - H_tokens, W_tokens = H // pH, W // pW - x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) - x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 - x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() - x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - + x_pos_ids = pos_ids_x(cap_feats_len + 1, H // pH, W // pW, bsz, device, transformer_options=transformer_options) if self.pad_tokens_multiple is not None: - pad_extra = (-x.shape[1]) % self.pad_tokens_multiple - x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x, pad_extra = pad_zimage(x, self.x_pad_token, self.pad_tokens_multiple) x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) + embeds += (x,) + freqs_cis += (self.rope_embedder(x_pos_ids).movedim(1, 2),) + return embeds, freqs_cis, cap_feats_len + len(freqs_cis) - 1 + + + def patchify_and_embed( + self, x: torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, 
num_tokens, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor, Optional[List[Tuple[int, int]]]]:
+        bsz = x.shape[0]
+        cap_mask = None # TODO?
+        main_siglip = None
+        orig_x = x
+
+        embeds = ([], [], [])
+        freqs_cis = ([], [], [])
+        leftover_cap = []
+
+        start_t = 0
+        omni = len(ref_latents) > 0
+        if omni:
+            for i, ref in enumerate(ref_latents):
+                if i < len(ref_contexts):
+                    ref_con = ref_contexts[i]
+                else:
+                    ref_con = None
+                if i < len(siglip_feats):
+                    sig_feat = siglip_feats[i]
+                else:
+                    sig_feat = None
+
+                out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
+                # distinct loop variable so the outer reference index i is not shadowed
+                for j, e in enumerate(out[0]):
+                    if e is not None:
+                        embeds[j].append(comfy.utils.repeat_to_batch_size(e, bsz))
+                        freqs_cis[j].append(out[1][j])
+                start_t = out[2]
+            leftover_cap = ref_contexts[len(ref_latents):]
+
+        H, W = x.shape[-2], x.shape[-1]
+        img_sizes = [(H, W)] * bsz
+        out = self.embed_all(x, cap_feats, main_siglip, offset=start_t, omni=omni, transformer_options=transformer_options)
+        img_len = out[0][-1].shape[1]
+        cap_len = out[0][0].shape[1]
+        for i, e in enumerate(out[0]):
+            if e is not None:
+                e = comfy.utils.repeat_to_batch_size(e, bsz)
+                embeds[i].append(e)
+                freqs_cis[i].append(out[1][i])
+        start_t = out[2]
+
+        for cap in leftover_cap:
+            out = self.embed_cap(cap, offset=start_t, bsz=bsz, device=x.device, dtype=x.dtype)
+            cap_len += out[0][0].shape[1]
+            embeds[0].append(comfy.utils.repeat_to_batch_size(out[0][0], bsz))
+            freqs_cis[0].append(out[1][0])
+            start_t += out[2]

         patches = transformer_options.get("patches", {})

         # refine context
+        cap_feats = torch.cat(embeds[0], dim=1)
+        cap_freqs_cis = torch.cat(freqs_cis[0], dim=1)
         for layer in self.context_refiner:
-            cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
+            cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
+
+        feats = (cap_feats,)
+        fc = (cap_freqs_cis,)
+
+        if omni and len(embeds[1]) > 0:
+            siglip_mask = None
+            siglip_feats_combined = torch.cat(embeds[1], dim=1)
+            siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
+            if self.siglip_refiner is not None:
+                for layer in self.siglip_refiner:
+                    siglip_feats_combined = layer(siglip_feats_combined, siglip_mask, siglip_feats_freqs_cis, transformer_options=transformer_options)
+            feats += (siglip_feats_combined,)
+            fc += (siglip_feats_freqs_cis,)

         padded_img_mask = None
+        x = torch.cat(embeds[-1], dim=1)
+        fc_x = torch.cat(freqs_cis[-1], dim=1)
+        if omni:
+            timestep_zero_index = [(x.shape[1] - img_len, x.shape[1])]
+        else:
+            timestep_zero_index = None
         x_input = x
         for i, layer in enumerate(self.noise_refiner):
-            x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
+            x = layer(x, padded_img_mask, fc_x, t, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
             if "noise_refiner" in patches:
                 for p in patches["noise_refiner"]:
-                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
+                    out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": fc_x, "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
                     if "img" in out:
                         x = out["img"]

-        padded_full_embed =
torch.cat((cap_feats, x), dim=1) + padded_full_embed = torch.cat(feats + (x,), dim=1) + if timestep_zero_index is not None: + ind = padded_full_embed.shape[1] - x.shape[1] + timestep_zero_index = [(ind + x.shape[1] - img_len, ind + x.shape[1])] + timestep_zero_index.append((feats[0].shape[1] - cap_len, feats[0].shape[1])) + mask = None - img_sizes = [(H, W)] * bsz - l_effective_cap_len = [cap_feats.shape[1]] * bsz - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + l_effective_cap_len = [padded_full_embed.shape[1] - img_len] * bsz + return padded_full_embed, mask, img_sizes, l_effective_cap_len, torch.cat(fc + (fc_x,), dim=1), timestep_zero_index def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( @@ -603,7 +808,11 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -618,20 +827,18 @@ class NextDiT(nn.Module): t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t - cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute - if self.clip_text_pooled_proj is not None: pooled = kwargs.get("clip_text_pooled", None) if pooled is not None: pooled = self.clip_text_pooled_proj(pooled) else: - pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, ref_latents=ref_latents, ref_contexts=ref_contexts, siglip_feats=siglip_feats, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) transformer_options["total_blocks"] = len(self.layers) @@ -639,7 +846,7 @@ class NextDiT(nn.Module): img_input = img for i, layer in enumerate(self.layers): transformer_options["block_index"] = i - img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) if "double_block" in patches: for p in patches["double_block"]: out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) @@ -648,8 +855,271 @@ class NextDiT(nn.Module): if "txt" in out: img[:, :cap_size[0]] = out["txt"] - img = self.final_layer(img, adaln_input) + img = self.final_layer(img, adaln_input, 
timestep_zero_index=timestep_zero_index) img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. 
+ + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. 
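+      • ``forward`` converts the ``-x0`` returned by ``_forward`` into the
+        ``(x + x0) / t`` prediction expected by ComfyUI's Euler sampler
+        (see the derivation in the ``forward`` comment below).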
+ """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, 
cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0..b193fe5e8 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,13 @@ except ImportError as e: raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError: + pass + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -365,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) if first_op_done == False: model_management.soft_empty_cache(True) if cleared_cache == False: @@ -517,6 +525,9 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + if kwargs.get("low_precision_attention", True) is False: + return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=skip_reshape, skip_output_reshape=skip_output_reshape, **kwargs) + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape @@ -563,6 +574,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, 
mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
+    if (q.device.type != "cuda" or
+        q.dtype not in (torch.float16, torch.bfloat16) or
+        mask is not None):
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        B, H, L, D = q.shape
+        if H != heads:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=True,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        q_s, k_s, v_s = q, k, v
+        N = q.shape[2]
+        dim_head = D
+    else:
+        B, N, inner_dim = q.shape
+        if inner_dim % heads != 0:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=False,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        dim_head = inner_dim // heads
+
+    if dim_head >= 256 or N <= 1024:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_reshape:
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
+    try:
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
+    except Exception as e:
+        exception_fallback = True
+        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
+        # fall back with the original tensors, preserving their layout
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if not skip_output_reshape:
+        out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+    return out

 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -650,6 +748,8 @@ optimized_attention_masked = optimized_attention
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if SAGE_ATTENTION3_IS_AVAILABLE:
+    register_attention_function("sage3", attention3_sage)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 681a55db5..fcbaa074f 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -14,10 +14,13 @@ if model_management.xformers_enabled_vae():
     import xformers.ops

 def torch_cat_if_needed(xl, dim):
+    xl = [x for x in xl if x is not None and x.shape[dim] > 0]
     if len(xl) > 1:
         return torch.cat(xl, dim)
-    else:
+    elif len(xl) == 1:
         return xl[0]
+    else:
+        return None

 def get_timestep_embedding(timesteps, embedding_dim):
     """
@@ -99,19 +102,7 @@ class VideoConv3d(nn.Module):
         return self.conv(x)

 def interpolate_up(x, scale_factor):
-    try:
-        return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
-    except: #operation not implemented for bf16
-        orig_shape = list(x.shape)
-        out_shape = orig_shape[:2]
-        for i in range(len(orig_shape) - 2):
-            out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
-        out = torch.empty(out_shape, dtype=x.dtype,
layout=x.layout, device=x.device) - split = 8 - l = out.shape[1] // split - for i in range(0, out.shape[1], l): - out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype) - return out + return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest") class Upsample(nn.Module): def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0): @@ -267,7 +258,8 @@ def slice_attention(q, k, v): r1[:, :, i:end] = torch.bmm(v, s2) del s2 break - except model_management.OOM_EXCEPTION as e: + except Exception as e: + model_management.raise_non_oom(e) model_management.soft_empty_cache(True) steps *= 2 if steps > 128: @@ -323,7 +315,8 @@ def pytorch_attention(q, k, v): try: out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = out.transpose(2, 3).reshape(orig_shape) - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") oom_fallback = True if oom_fallback: @@ -394,7 +387,8 @@ class Model(nn.Module): attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = self.ch*4 self.num_resolutions = len(ch_mult) @@ -548,7 +542,8 @@ class Encoder(nn.Module): conv3d=False, time_compress=None, **ignore_kwargs): super().__init__() - if use_linear_attn: attn_type = "linear" + if use_linear_attn: + attn_type = "linear" self.ch = ch self.temb_ch = 0 self.num_resolutions = len(ch_mult) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac..295310df6 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -18,6 +18,8 @@ import comfy.patcher_extension import comfy.ops ops = comfy.ops.disable_weight_init +from ..sdpose import HeatmapHead + class TimestepBlock(nn.Module): """ Any module where forward() takes timestep embeddings as a second argument. 
@@ -441,6 +443,7 @@ class UNetModel(nn.Module): disable_temporal_crossattention=False, max_ddpm_temb_period=10000, attn_precision=None, + heatmap_head=False, device=None, operations=ops, ): @@ -827,6 +830,9 @@ class UNetModel(nn.Module): #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits ) + if heatmap_head: + self.heatmap_head = HeatmapHead(device=device, dtype=self.dtype, operations=operations) + def forward(self, x, timesteps=None, context=None, y=None, control=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, diff --git a/comfy/ldm/modules/ema.py b/comfy/ldm/modules/ema.py index bded25019..96ee6e895 100644 --- a/comfy/ldm/modules/ema.py +++ b/comfy/ldm/modules/ema.py @@ -45,7 +45,7 @@ class LitEma(nn.Module): shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def copy_to(self, model): m_param = dict(model.named_parameters()) @@ -54,7 +54,7 @@ class LitEma(nn.Module): if m_param[key].requires_grad: m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) else: - assert not key in self.m_name2s_name + assert key not in self.m_name2s_name def store(self, parameters): """ diff --git a/comfy/ldm/modules/sdpose.py b/comfy/ldm/modules/sdpose.py new file mode 100644 index 000000000..d67b60b76 --- /dev/null +++ b/comfy/ldm/modules/sdpose.py @@ -0,0 +1,130 @@ +import torch +import numpy as np +from scipy.ndimage import gaussian_filter + +class HeatmapHead(torch.nn.Module): + def __init__( + self, + in_channels=640, + out_channels=133, + input_size=(768, 1024), + heatmap_scale=4, + deconv_out_channels=(640,), + deconv_kernel_sizes=(4,), + conv_out_channels=(640,), + conv_kernel_sizes=(1,), + final_layer_kernel_size=1, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.heatmap_size = (input_size[0] // heatmap_scale, input_size[1] // heatmap_scale) + self.scale_factor = ((np.array(input_size) - 1) / (np.array(self.heatmap_size) - 1)).astype(np.float32) + + # Deconv layers + if deconv_out_channels: + deconv_layers = [] + for out_ch, kernel_size in zip(deconv_out_channels, deconv_kernel_sizes): + if kernel_size == 4: + padding, output_padding = 1, 0 + elif kernel_size == 3: + padding, output_padding = 1, 1 + elif kernel_size == 2: + padding, output_padding = 0, 0 + else: + raise ValueError(f'Unsupported kernel size {kernel_size}') + + deconv_layers.extend([ + operations.ConvTranspose2d(in_channels, out_ch, kernel_size, + stride=2, padding=padding, output_padding=output_padding, bias=False, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.deconv_layers = torch.nn.Sequential(*deconv_layers) + else: + self.deconv_layers = torch.nn.Identity() + + # Conv layers + if conv_out_channels: + conv_layers = [] + for out_ch, kernel_size in zip(conv_out_channels, conv_kernel_sizes): + padding = (kernel_size - 1) // 2 + conv_layers.extend([ + operations.Conv2d(in_channels, out_ch, kernel_size, + stride=1, padding=padding, device=device, dtype=dtype), + torch.nn.InstanceNorm2d(out_ch, device=device, dtype=dtype), + torch.nn.SiLU(inplace=True) + ]) + in_channels = out_ch + self.conv_layers = torch.nn.Sequential(*conv_layers) + else: + self.conv_layers = torch.nn.Identity() + + 
self.final_layer = operations.Conv2d(in_channels, out_channels, kernel_size=final_layer_kernel_size, padding=final_layer_kernel_size // 2, device=device, dtype=dtype) + + def forward(self, x): # Decode heatmaps to keypoints + heatmaps = self.final_layer(self.conv_layers(self.deconv_layers(x))) + heatmaps_np = heatmaps.float().cpu().numpy() # (B, K, H, W) + B, K, H, W = heatmaps_np.shape + + batch_keypoints = [] + batch_scores = [] + + for b in range(B): + hm = heatmaps_np[b].copy() # (K, H, W) + + # --- vectorised argmax --- + flat = hm.reshape(K, -1) + idx = np.argmax(flat, axis=1) + scores = flat[np.arange(K), idx].copy() + y_locs, x_locs = np.unravel_index(idx, (H, W)) + keypoints = np.stack([x_locs, y_locs], axis=-1).astype(np.float32) # (K, 2) in heatmap space + invalid = scores <= 0. + keypoints[invalid] = -1 + + # --- DARK sub-pixel refinement (UDP) --- + # 1. Gaussian blur with max-preserving normalisation + border = 5 # (kernel-1)//2 for kernel=11 + for k in range(K): + origin_max = np.max(hm[k]) + dr = np.zeros((H + 2 * border, W + 2 * border), dtype=np.float32) + dr[border:-border, border:-border] = hm[k].copy() + dr = gaussian_filter(dr, sigma=2.0) + hm[k] = dr[border:-border, border:-border].copy() + cur_max = np.max(hm[k]) + if cur_max > 0: + hm[k] *= origin_max / cur_max + # 2. Log-space for Taylor expansion + np.clip(hm, 1e-3, 50., hm) + np.log(hm, hm) + # 3. Hessian-based Newton step + hm_pad = np.pad(hm, ((0, 0), (1, 1), (1, 1)), mode='edge').flatten() + index = keypoints[:, 0] + 1 + (keypoints[:, 1] + 1) * (W + 2) + index += (W + 2) * (H + 2) * np.arange(0, K) + index = index.astype(int).reshape(-1, 1) + i_ = hm_pad[index] + ix1 = hm_pad[index + 1] + iy1 = hm_pad[index + W + 2] + ix1y1 = hm_pad[index + W + 3] + ix1_y1_ = hm_pad[index - W - 3] + ix1_ = hm_pad[index - 1] + iy1_ = hm_pad[index - 2 - W] + dx = 0.5 * (ix1 - ix1_) + dy = 0.5 * (iy1 - iy1_) + derivative = np.concatenate([dx, dy], axis=1).reshape(K, 2, 1) + dxx = ix1 - 2 * i_ + ix1_ + dyy = iy1 - 2 * i_ + iy1_ + dxy = 0.5 * (ix1y1 - ix1 - iy1 + i_ + i_ - ix1_ - iy1_ + ix1_y1_) + hessian = np.concatenate([dxx, dxy, dxy, dyy], axis=1).reshape(K, 2, 2) + hessian = np.linalg.inv(hessian + np.finfo(np.float32).eps * np.eye(2)) + keypoints -= np.einsum('imn,ink->imk', hessian, derivative).squeeze(axis=-1) + + # --- restore to input image space --- + keypoints = keypoints * self.scale_factor + keypoints[invalid] = -1 + + batch_keypoints.append(keypoints) + batch_scores.append(scores) + + return batch_keypoints, batch_scores diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index fab145f1c..f982afc2b 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking( try: attn_probs = attn_scores.softmax(dim=-1) del attn_scores - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined torch.exp(attn_scores, out=attn_scores) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index a6d408104..c0aae9240 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -2,6 +2,196 @@ import torch import math from .model import 
QwenImageTransformer2DModel +from .model import QwenImageTransformerBlock + + +class QwenImageFunControlBlock(QwenImageTransformerBlock): + def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None): + super().__init__( + dim=dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations, + ) + self.has_before_proj = has_before_proj + if has_before_proj: + self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + +class QwenImageFunControlNetModel(torch.nn.Module): + def __init__( + self, + control_in_features=132, + inner_dim=3072, + num_attention_heads=24, + attention_head_dim=128, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.main_model_double = main_model_double + self.injection_layers = tuple(injection_layers) + # Keep base hint scaling at 1.0 so user-facing strength behaves similarly + # to the reference Gen2/VideoX implementation around strength=1. + self.hint_scale = 1.0 + self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype) + + self.control_blocks = torch.nn.ModuleList([]) + for i in range(num_control_blocks): + self.control_blocks.append( + QwenImageFunControlBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + has_before_proj=(i == 0), + dtype=dtype, + device=device, + operations=operations, + ) + ) + + def _process_hint_tokens(self, hint): + if hint is None: + return None + if hint.ndim == 4: + hint = hint.unsqueeze(2) + + # Fun checkpoints are trained with 33 latent channels before 2x2 packing: + # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features. + # Default behavior (no inpaint input in stock Apply ControlNet) should use + # zeros for mask/inpaint branches, matching VideoX fallback semantics. 
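+        # Packing math: 33 pre-packing channels * (2x2 spatial packing) = 132
+        # input features, so dividing the Linear's in_features by 4 below
+        # recovers the pre-packing channel count (132 // 4 == 33).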
+ expected_c = self.control_img_in.weight.shape[1] // 4 + if hint.shape[1] == 16 and expected_c == 33: + zeros_mask = torch.zeros_like(hint[:, :1]) + zeros_inpaint = torch.zeros_like(hint) + hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1) + + bs, c, t, h, w = hint.shape + hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view( + orig_shape[0], + orig_shape[1], + orig_shape[-3], + orig_shape[-2] // 2, + 2, + orig_shape[-1] // 2, + 2, + ) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape( + bs, + t * ((h + 1) // 2) * ((w + 1) // 2), + c * 4, + ) + + expected_in = self.control_img_in.weight.shape[1] + cur_in = hidden_states.shape[-1] + if cur_in < expected_in: + pad = torch.zeros( + (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + hidden_states = torch.cat([hidden_states, pad], dim=-1) + elif cur_in > expected_in: + hidden_states = hidden_states[:, :, :expected_in] + + return hidden_states + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + hint=None, + transformer_options={}, + base_model=None, + **kwargs, + ): + if base_model is None: + raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.") + + encoder_hidden_states_mask = attention_mask + # Keep attention mask disabled inside Fun control blocks to mirror + # VideoX behavior (they rely on seq lengths for RoPE, not masked attention). + encoder_hidden_states_mask = None + + hidden_states, img_ids, _ = base_model.process_img(x) + hint_tokens = self._process_hint_tokens(hint) + if hint_tokens is None: + raise RuntimeError("Qwen Fun ControlNet requires a control hint image.") + + if hint_tokens.shape[1] != hidden_states.shape[1]: + max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1]) + hint_tokens = hint_tokens[:, :max_tokens] + hidden_states = hidden_states[:, :max_tokens] + img_ids = img_ids[:, :max_tokens] + + txt_start = round( + max( + ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ) + ) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous() + + hidden_states = base_model.img_in(hidden_states) + encoder_hidden_states = base_model.txt_norm(context) + encoder_hidden_states = base_model.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + base_model.time_text_embed(timesteps, hidden_states) + if guidance is None + else base_model.time_text_embed(timesteps, guidance, hidden_states) + ) + + c = self.control_img_in(hint_tokens) + + for i, block in enumerate(self.control_blocks): + if i == 0: + c_in = block.before_proj(c) + hidden_states + all_c = [] + else: + all_c = list(torch.unbind(c, dim=0)) + c_in = all_c.pop(-1) + + encoder_hidden_states, c_out = block( + hidden_states=c_in, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + c_skip = block.after_proj(c_out) * self.hint_scale + all_c += [c_skip, c_out] + c = 
torch.stack(all_c, dim=0) + + hints = torch.unbind(c, dim=0)[:-1] + + controlnet_block_samples = [None] * self.main_model_double + for local_idx, base_idx in enumerate(self.injection_layers): + if local_idx < len(hints) and base_idx < len(controlnet_block_samples): + controlnet_block_samples[base_idx] = hints[local_idx] + + return {"input": controlnet_block_samples} class QwenImageControlNetModel(QwenImageTransformer2DModel): diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 00c597535..0862f72f7 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -149,6 +149,9 @@ class Attention(nn.Module): seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] + transformer_patches = transformer_options.get("patches", {}) + extra_options = transformer_options.copy() + # Project and reshape to BHND format (batch, heads, seq, dim) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() @@ -167,11 +170,24 @@ class Attention(nn.Module): joint_key = torch.cat([txt_key, img_key], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2) + if encoder_hidden_states_mask is not None: + attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) + attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask + else: + attn_mask = None + + extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]] + if "attn1_patch" in transformer_patches: + patch = transformer_patches["attn1_patch"] + for p in patch: + out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options) + joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask) + joint_query = apply_rope1(joint_query, image_rotary_emb) joint_key = apply_rope1(joint_key, image_rotary_emb) joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, - attention_mask, transformer_options=transformer_options, + attn_mask, transformer_options=transformer_options, skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] @@ -430,11 +446,15 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = context encoder_hidden_states_mask = attention_mask + if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask): + encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max + hidden_states, img_ids, orig_shape = self.process_img(x) num_embeds = hidden_states.shape[1] timestep_zero_index = None if ref_latents is not None: + ref_num_tokens = [] h = 0 w = 0 index = 0 @@ -465,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module): kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) hidden_states = torch.cat([hidden_states, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) + ref_num_tokens.append(kontext.shape[1]) if timestep_zero: if index > 0: timestep = torch.cat([timestep, timestep * 0], dim=0) timestep_zero_index = num_embeds + transformer_options = transformer_options.copy() + 
transformer_options["reference_image_num_tokens"] = ref_num_tokens txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) - ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() - del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states) @@ -486,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + if "post_input" in patches: + for p in patches["post_input"]: + out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + img_ids = out["img_ids"] + txt_ids = out["txt_ids"] + + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() + del ids, txt_ids, img_ids + transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index 30b4b4721..304936ff4 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -71,7 +71,7 @@ def count_params(model, verbose=False): def instantiate_from_config(config): - if not "target" in config: + if "target" not in config: if config == '__is_first_stage__': return None elif config == "__is_unconditional__": diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 4216ce831..b2287dba9 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -62,6 +62,8 @@ class WanSelfAttention(nn.Module): x(Tensor): Shape [B, L, num_heads, C / num_heads] freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] """ + patches = transformer_options.get("patches", {}) + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim def qkv_fn_q(x): @@ -86,6 +88,10 @@ class WanSelfAttention(nn.Module): transformer_options=transformer_options, ) + if "attn1_patch" in patches: + for p in patches["attn1_patch"]: + x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options}) + x = self.o(x) return x @@ -225,6 +231,8 @@ class WanAttentionBlock(nn.Module): """ # assert e.dtype == torch.float32 + patches = transformer_options.get("patches", {}) + if e.ndim < 4: e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1) else: @@ -242,6 +250,11 @@ class WanAttentionBlock(nn.Module): # cross-attention & ffn x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) + + if "attn2_patch" in patches: + for p in patches["attn2_patch"]: + x = p({"x": x, "transformer_options": transformer_options}) + y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x))) x = torch.addcmul(x, y, repeat_e(e[5], x)) return x @@ -488,7 +501,7 @@ class WanModel(torch.nn.Module): self.blocks = nn.ModuleList([ wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings) - for _ in range(num_layers) + for i in range(num_layers) ]) # head @@ -541,6 +554,7 @@ class 
WanModel(torch.nn.Module): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -738,6 +752,7 @@ class VaceWanModel(WanModel): # embeddings x = self.patch_embedding(x.float()).to(x.dtype) grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes x = x.flatten(2).transpose(1, 2) # time embeddings @@ -1606,3 +1621,118 @@ class HumoWanModel(WanModel): # unpatchify x = self.unpatchify(x, grid_sizes) return x + +class SCAILWanModel(WanModel): + def __init__(self, model_type="scail", patch_size=(1, 2, 2), in_dim=20, dim=5120, operations=None, device=None, dtype=None, **kwargs): + super().__init__(model_type='i2v', patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs) + + self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32) + + def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs): + + if reference_latent is not None: + x = torch.cat((reference_latent, x), dim=2) + + # embeddings + x = self.patch_embedding(x.float()).to(x.dtype) + grid_sizes = x.shape[2:] + transformer_options["grid_sizes"] = grid_sizes + x = x.flatten(2).transpose(1, 2) + + scail_pose_seq_len = 0 + if pose_latents is not None: + scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype) + scail_x = scail_x.flatten(2).transpose(1, 2) + scail_pose_seq_len = scail_x.shape[1] + x = torch.cat([x, scail_x], dim=1) + del scail_x + + # time embeddings + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype)) + e = e.reshape(t.shape[0], -1, e.shape[-1]) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + + # context + context = self.text_embedding(context) + + context_img_len = None + if clip_fea is not None: + if self.img_emb is not None: + context_clip = self.img_emb(clip_fea) # bs x 257 x dim + context = torch.cat([context_clip, context], dim=1) + context_img_len = clip_fea.shape[-2] + + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"]) + return out + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap}) + x = out["img"] + else: + x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, transformer_options=transformer_options) + + # head + x = self.head(x, e) + + if scail_pose_seq_len > 0: + x = x[:, :-scail_pose_seq_len] + + # unpatchify + x = self.unpatchify(x, grid_sizes) + + if reference_latent is not None: + x = x[:, :, reference_latent.shape[2]:] + + return x + + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}): + main_freqs = 
super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options) + + if pose_latents is None: + return main_freqs + + ref_t_patches = 0 + if reference_latent is not None: + ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0] + + F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1] + + # if pose is at half resolution, scale_y/scale_x=2 stretches the position range to cover the same RoPE extent as the main frames + h_scale = h / H_pose + w_scale = w / W_pose + + # 120 w-offset and shift 0.5 to place positions at midpoints (0.5, 2.5, ...) to match the original code + h_shift = (h_scale - 1) / 2 + w_shift = (w_scale - 1) / 2 + pose_transformer_options = {"rope_options": {"shift_y": h_shift, "shift_x": 120.0 + w_shift, "scale_y": h_scale, "scale_x": w_scale}} + pose_freqs = super().rope_encode(F_pose, H_pose, W_pose, t_start=t_start+ref_t_patches, device=device, dtype=dtype, transformer_options=pose_transformer_options) + + return torch.cat([main_freqs, pose_freqs], dim=1) + + def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs): + bs, c, t, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if pose_latents is not None: + pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size) + + t_len = t + if time_dim_concat is not None: + time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size) + x = torch.cat([x, time_dim_concat], dim=2) + t_len = x.shape[2] + + reference_latent = None + if "reference_latent" in kwargs: + reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size) + t_len += reference_latent.shape[2] + + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent) + return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w] diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py new file mode 100644 index 000000000..c9dd98c4d --- /dev/null +++ b/comfy/ldm/wan/model_multitalk.py @@ -0,0 +1,500 @@ +import torch +from einops import rearrange, repeat +import comfy +from comfy.ldm.modules.attention import optimized_attention + + +def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8): + scale = 1.0 / visual_q.shape[-1] ** 0.5 + visual_q = visual_q.transpose(1, 2) * scale + + B, H, x_seqlens, K = visual_q.shape + + x_ref_attn_maps = [] + for class_idx, ref_target_mask in enumerate(ref_target_masks): + ref_target_mask = ref_target_mask.view(1, 1, 1, -1) + + x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype) + chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens) + + for i in range(0, x_seqlens, chunk_size): + end_i = min(i + chunk_size, x_seqlens) + + attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens + + # Apply softmax + attn_max = attn_chunk.max(dim=-1, keepdim=True).values + attn_chunk = (attn_chunk - attn_max).exp() + attn_sum = attn_chunk.sum(dim=-1, keepdim=True) + attn_chunk = attn_chunk / (attn_sum 
+ 1e-8) + + # Apply mask and sum + masked_attn = attn_chunk * ref_target_mask + x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8) + + del attn_chunk, masked_attn + + # Average across heads + x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens + x_ref_attn_maps.append(x_ref_attnmap) + + del visual_q, ref_k + + return torch.cat(x_ref_attn_maps, dim=0) + +def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2): + """Args: + visual_q (torch.tensor): B M H K + ref_k (torch.tensor): B M H K + shape (tuple): (N_t, N_h, N_w) + ref_target_masks: [class_num, N_h * N_w] + """ + + N_t, N_h, N_w = shape + + x_seqlens = N_h * N_w + ref_k = ref_k[:, :x_seqlens] + _, seq_lens, heads, _ = visual_q.shape + class_num, _ = ref_target_masks.shape + x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q) + + split_chunk = heads // split_num + + for i in range(split_num): + x_ref_attn_maps_perhead = calculate_x_ref_attn_map( + visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :], + ref_target_masks + ) + x_ref_attn_maps += x_ref_attn_maps_perhead + + return x_ref_attn_maps / split_num + + +def normalize_and_scale(column, source_range, target_range, epsilon=1e-8): + source_min, source_max = source_range + new_min, new_max = target_range + normalized = (column - source_min) / (source_max - source_min + epsilon) + scaled = normalized * (new_max - new_min) + new_min + return scaled + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def get_audio_embeds(encoded_audio, audio_start, audio_end): + audio_embs = [] + human_num = len(encoded_audio) + audio_frames = encoded_audio[0].shape[0] + + indices = (torch.arange(4 + 1) - 2) * 1 + + for human_idx in range(human_num): + if audio_end > audio_frames: # if there is not enough audio for the current window, pad with the first audio frame as that's most likely silence + pad_len = audio_end - audio_frames + pad_shape = list(encoded_audio[human_idx].shape) + pad_shape[0] = pad_len + pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1))) + encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0) + else: + encoded_audio_in = encoded_audio[human_idx] + center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0) + center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1) + audio_emb = encoded_audio_in[center_indices].unsqueeze(0) + audio_embs.append(audio_emb) + + return torch.cat(audio_embs, dim=0) + + +def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end): + audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end) + + first_frame_audio_emb_s = audio_embs[:, :1, ...] + latter_frame_audio_emb = audio_embs[:, 1:, ...] + latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4) + + middle_index = audio_proj.seq_len // 2 + + latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...] + latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
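+ # mirror of the first sub-frame above: the last sub-frame keeps the right half of its audio window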
+ latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...] + latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c") + latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2) + + audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s) + audio_emb = torch.cat(audio_emb.split(1), dim=2) + + return audio_emb + + +class RotaryPositionalEmbedding1D(torch.nn.Module): + def __init__(self, + head_dim, + ): + super().__init__() + self.head_dim = head_dim + self.base = 10000 + + def precompute_freqs_cis_1d(self, pos_indices): + freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim)) + freqs = freqs.to(pos_indices.device) + freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + return freqs + + def forward(self, x, pos_indices): + freqs_cis = self.precompute_freqs_cis_1d(pos_indices) + + x_ = x.float() + + freqs_cis = freqs_cis.float().to(x.device) + cos, sin = freqs_cis.cos(), freqs_cis.sin() + cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d') + x_ = (x_ * cos) + (rotate_half(x_) * sin) + + return x_.type_as(x) + +class SingleStreamAttention(torch.nn.Module): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + device=None, dtype=None, operations=None + ) -> None: + super().__init__() + self.dim = dim + self.encoder_hidden_states_dim = encoder_hidden_states_dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor: + N_t, N_h, N_w = shape + + expected_tokens = N_t * N_h * N_w + actual_tokens = x.shape[1] + x_extra = None + + if actual_tokens != expected_tokens: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + + B = x.shape[0] + S = N_h * N_w + x = x.view(B * N_t, S, self.dim) + + # get q for hidden_state + q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim) + + # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim) + kv = self.kv_linear(encoder_hidden_states) + encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2) + + # q: (B * N_t, S, num_heads, head_dim), e.g. torch.Size([21, 1024, 40, 128]); transposed below to put heads before the sequence dim + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # linear transform + x = self.proj(x.reshape(B * N_t, S, self.dim)) + x = x.view(B, N_t * S, self.dim) + + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + +class SingleStreamMultiAttention(SingleStreamAttention): + def __init__( + self, + dim: int, + encoder_hidden_states_dim: int, + num_heads: int, + qkv_bias: bool, + class_range: int = 24, + class_interval: int = 4, + device=None,
dtype=None, operations=None + ) -> None: + super().__init__( + dim=dim, + encoder_hidden_states_dim=encoder_hidden_states_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + device=device, + dtype=dtype, + operations=operations + ) + + # Rotary-embedding layout parameters + self.class_interval = class_interval + self.class_range = class_range + self.max_humans = self.class_range // self.class_interval + + # Constant bucket used for background tokens + self.rope_bak = int(self.class_range // 2) + + self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim) + + def forward( + self, + x: torch.Tensor, + encoder_hidden_states: torch.Tensor, + shape=None, + x_ref_attn_map=None + ) -> torch.Tensor: + encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device) + human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1 + # Single-speaker fall-through + if human_num <= 1: + return super().forward(x, encoder_hidden_states, shape) + + N_t, N_h, N_w = shape + + x_extra = None + if x.shape[0] * N_t != encoder_hidden_states.shape[0]: + x_extra = x[:, -N_h * N_w:, :] + x = x[:, :-N_h * N_w, :] + N_t = N_t - 1 + x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t) + + # Query projection + B, N, C = x.shape + q = self.q_linear(x) + q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) + + # Use `class_range` logic for 2 speakers + rope_h1 = (0, self.class_interval) + rope_h2 = (self.class_range - self.class_interval, self.class_range) + rope_bak = int(self.class_range // 2) + + # Normalize and scale attention maps for each speaker + max_values = x_ref_attn_map.max(1).values[:, None, None] + min_values = x_ref_attn_map.min(1).values[:, None, None] + max_min_values = torch.cat([max_values, min_values], dim=2) + + human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min() + human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min() + + human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1) + human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2) + back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device) + + # Token-wise speaker dominance + max_indices = x_ref_attn_map.argmax(dim=0) + normalized_map = torch.stack([human1, human2, back], dim=1) + normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices] + + # Apply rotary to Q + q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + q = self.rope_1d(q, normalized_pos) + q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Keys / Values + _, N_a, _ = encoder_hidden_states.shape + encoder_kv = self.kv_linear(encoder_hidden_states) + encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + encoder_k, encoder_v = encoder_kv.unbind(0) + + # Rotary for keys – assign centre of each speaker bucket to its context tokens + per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device) + per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2 + per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2 + encoder_pos = torch.cat([per_frame] * N_t, dim=0) + + encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t) + encoder_k = self.rope_1d(encoder_k, encoder_pos) + encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t) + + # Final attention + q = rearrange(q, "B H M K -> B M H K") 
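+ # the rearranges put q/k/v in (B, seq, heads, head_dim); the transpose(1, 2) at the attention call below restores (B, heads, seq, head_dim), the layout optimized_attention expects with skip_reshape=True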
+ encoder_k = rearrange(encoder_k, "B H M K -> B M H K") + encoder_v = rearrange(encoder_v, "B H M K -> B M H K") + + x = optimized_attention( + q.transpose(1, 2), + encoder_k.transpose(1, 2), + encoder_v.transpose(1, 2), + heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2) + + # Linear projection + x = x.reshape(B, N, C) + x = self.proj(x) + + # Restore original layout + x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t) + if x_extra is not None: + x = torch.cat([x, torch.zeros_like(x_extra)], dim=1) + + return x + + +class MultiTalkAudioProjModel(torch.nn.Module): + def __init__( + self, + seq_len: int = 5, + seq_len_vf: int = 12, + blocks: int = 12, + channels: int = 768, + intermediate_dim: int = 512, + out_dim: int = 768, + context_tokens: int = 32, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.seq_len = seq_len + self.blocks = blocks + self.channels = channels + self.input_dim = seq_len * blocks * channels + self.input_dim_vf = seq_len_vf * blocks * channels + self.intermediate_dim = intermediate_dim + self.context_tokens = context_tokens + self.out_dim = out_dim + + # define multiple linear layers + self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype) + self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype) + self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype) + self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype) + + def forward(self, audio_embeds, audio_embeds_vf): + video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1] + B, _, _, S, C = audio_embeds.shape + + # process audio of first frame + audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c") + batch_size, window_size, blocks, channels = audio_embeds.shape + audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels) + + # process audio of latter frame + audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c") + batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape + audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf) + + # first projection + audio_embeds = torch.relu(self.proj1(audio_embeds)) + audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf)) + audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B) + audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B) + audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1) + batch_size_c, N_t, C_a = audio_embeds_c.shape + audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a) + + # second projection + audio_embeds_c = torch.relu(self.proj2(audio_embeds_c)) + + context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim) + + # normalization and reshape + context_tokens = self.norm(context_tokens) + context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length) + + return context_tokens + + +class WanMultiTalkAttentionBlock(torch.nn.Module): + def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None): + super().__init__() + self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations) + self.norm_x = 
operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True) + + +class MultiTalkGetAttnMapPatch: + def __init__(self, ref_target_masks=None): + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + x = kwargs["x"] + + if self.ref_target_masks is not None: + x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device)) + transformer_options["x_ref_attn_map"] = x_ref_attn_map + return x + + +class MultiTalkCrossAttnPatch: + def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None): + self.model_patch = model_patch + self.audio_scale = audio_scale + self.ref_target_masks = ref_target_masks + + def __call__(self, kwargs): + transformer_options = kwargs.get("transformer_options", {}) + block_idx = transformer_options.get("block_index", None) + x = kwargs["x"] + if block_idx is None: + return torch.zeros_like(x) + + audio_embeds = transformer_options.get("audio_embeds") + x_ref_attn_map = transformer_options.pop("x_ref_attn_map", None) + + norm_x = self.model_patch.model.blocks[block_idx].norm_x(x) + x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn( + norm_x, audio_embeds.to(x.dtype), + shape=transformer_options["grid_sizes"], + x_ref_attn_map=x_ref_attn_map + ) + x = x + x_audio * self.audio_scale + return x + + def models(self): + return [self.model_patch] + +class MultiTalkApplyModelWrapper: + def __init__(self, init_latents): + self.init_latents = init_latents + + def __call__(self, executor, x, *args, **kwargs): + x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x) + samples = executor(x, *args, **kwargs) + return samples + + +class InfiniteTalkOuterSampleWrapper: + def __init__(self, motion_frames_latent, model_patch, is_extend=False): + self.motion_frames_latent = motion_frames_latent + self.model_patch = model_patch + self.is_extend = is_extend + + def __call__(self, executor, *args, **kwargs): + model_patcher = executor.class_obj.model_patcher + model_options = executor.class_obj.model_options + process_latent_in = model_patcher.model.process_latent_in + + # for InfiniteTalk, the first latent(s) of the model input always need to be replaced on every step + if self.motion_frames_latent is not None: + wrappers = model_options["transformer_options"]["wrappers"] + w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {}) + w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(self.motion_frames_latent))] + + # run the sampling process + result = executor(*args, **kwargs) + + # insert motion frames before decoding + if self.is_extend: + overlap = self.motion_frames_latent.shape[2] + result = torch.cat([self.motion_frames_latent.to(result), result[:, :, overlap:]], dim=2) + + return result + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + if self.motion_frames_latent is not None: + self.motion_frames_latent = self.motion_frames_latent.to(device_or_dtype) + return self diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 08315f1a8..57b0dabf7 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from comfy.ldm.modules.diffusionmodules.model import vae_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention, torch_cat_if_needed import comfy.ops
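(Aside on the import change above: torch_cat_if_needed is defined outside this diff, in comfy/ldm/modules/diffusionmodules/model.py. A minimal sketch of the behavior the call sites below appear to rely on, stated as an assumption rather than the actual definition:

import torch

def torch_cat_if_needed(tensors, dim=0):
    # sketch with assumed semantics: drop cache entries that are absent (None)
    tensors = [t for t in tensors if t is not None]
    if len(tensors) == 1:
        # a single remaining chunk needs no concatenation (and no copy)
        return tensors[0]
    return torch.cat(tensors, dim=dim)

The CausalConv3d.forward below passes [padding, cache_x, x] where cache_x may be None, and a plain torch.cat would raise on the None entry.)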
ops = comfy.ops.disable_weight_init @@ -20,22 +20,29 @@ class CausalConv3d(ops.Conv3d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) + self._padding = 2 * self.padding[0] + self.padding = (0, self.padding[1], self.padding[2]) def forward(self, x, cache_x=None, cache_list=None, cache_idx=None): if cache_list is not None: cache_x = cache_list[cache_idx] cache_list[cache_idx] = None - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] + if cache_x is None and x.shape[2] == 1: + #Fast path - the op emulates the causal zero padding by truncating the weight, + #saving the math on a pile of zeros. + return super().forward(x, autopad="causal_zero") + + if self._padding > 0: + padding_needed = self._padding + if cache_x is not None: + cache_x = cache_x.to(x.device) + padding_needed = max(0, padding_needed - cache_x.shape[2]) + padding_shape = list(x.shape) + padding_shape[2] = padding_needed + padding = torch.zeros(padding_shape, device=x.device, dtype=x.dtype) + x = torch_cat_if_needed([padding, cache_x, x], dim=2) del cache_x - x = F.pad(x, padding) return super().forward(x) @@ -92,7 +99,7 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -102,22 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -138,19 +130,24 @@ if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x - feat_idx[0] += 1 + + deferred_x = feat_cache[idx + 1] + if deferred_x is not None: + x = torch.cat([deferred_x, x], 2) + feat_cache[idx + 1] = None + + if x.shape[2] == 1 and not final: + feat_cache[idx + 1] = x + x = None + + feat_idx[0] += 2 return x @@ -170,19 +167,12 @@ class ResidualBlock(nn.Module): self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x =
x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -206,7 +196,7 @@ class AttentionBlock(nn.Module): self.proj = ops.Conv2d(dim, dim, 1) self.optimized_attention = vae_attention() - def forward(self, x): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): identity = x b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') @@ -276,17 +266,10 @@ class Encoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -296,14 +279,16 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache, feat_idx, final=final) + if x is None: + return None else: x = layer(x) ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -311,14 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -382,18 +360,48 @@ class Decoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) + def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2: + for frame_idx in range(0, x.shape[2], 2): + self.run_up( + layer_idx + 1, + [x[:, :, frame_idx:frame_idx + 2, :, :]], + feat_cache, + feat_idx.copy(), + out_chunks, + ) + del x + return + + next_x_ref = [x] + del x + self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) + def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: idx 
= feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -402,42 +410,21 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + out_chunks = [] + + self.run_up(0, [x], feat_cache, feat_idx, out_chunks) + return out_chunks -def count_conv3d(model): +def count_cache_layers(model): count = 0 for m in model.modules(): - if isinstance(m, CausalConv3d): + if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'): count += 1 return count @@ -452,6 +439,7 @@ class WanVAE(nn.Module): attn_scales=[], temperal_downsample=[True, True, False], image_channels=3, + conv_out_channels=3, dropout=0.0): super().__init__() self.dim = dim @@ -467,16 +455,19 @@ attn_scales, self.temperal_downsample, dropout) self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, image_channels, dim_mult, num_res_blocks, + self.decoder = Decoder3d(dim, z_dim, conv_out_channels, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout) def encode(self, x): conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 - ## split the encode input x along time into chunks of 1, 4, 4, 4, ....
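+ # worked example of the new chunking: t=17 is already of the form 4N+1 so it stays 17, iter_=9, and the encoder sees frame 0 alone followed by eight 2-frame chunks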
+ t = 1 + ((t - 1) // 4) * 4 + iter_ = 1 + (t - 1) // 2 + feat_map = None + if iter_ > 1: + feat_map = [None] * count_cache_layers(self.encoder) + ## split the encode input x along time into chunks of 1, 2, 2, 2, ... (the frame count is first rounded down to 4N+1) for i in range(iter_): conv_idx = [0] if i == 0: @@ -486,19 +477,23 @@ else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, - feat_idx=conv_idx) + feat_idx=conv_idx, + final=(i == (iter_ - 1))) + if out_ is None: + continue out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) return mu def decode(self, z): - conv_idx = [0] - feat_map = [None] * count_conv3d(self.decoder) # z: [b,c,t,h,w] - - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 + feat_map = None + if iter_ > 1: + feat_map = [None] * count_cache_layers(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] @@ -509,8 +504,8 @@ feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2) diff --git a/comfy/lora.py b/comfy/lora.py index 2ed0acb9d..63ee85323 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}): for k in sdk: if k.endswith(".weight"): key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names + tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models + if tp > 0 and not k.startswith("clip_"): + key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" clip_l_present = False @@ -260,6 +263,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer + key_map[k[:-len(".weight")]] = to #DiffSynth lora format for k in sdk: hidden_size = model.model_config.unet_config.get("hidden_size", 0) if k.endswith(".weight") and ".linear1."
in k: @@ -322,6 +326,7 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["transformer.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to if isinstance(model, comfy.model_base.Kandinsky5): for k in sdk: @@ -330,6 +335,13 @@ def model_lora_keys_unet(model, key_map={}): key_map["{}".format(key_lora)] = k key_map["transformer.{}".format(key_lora)] = k + if isinstance(model, comfy.model_base.ACEStep15): + for k in sdk: + if k.startswith("diffusion_model.decoder.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model.decoder."):-len(".weight")] + key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format + return key_map @@ -366,6 +378,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten return padded_tensor +def calculate_shape(patches, weight, key, original_weights=None): + current_shape = weight.shape + + for p in patches: + v = p[1] + offset = p[3] + + # Offsets restore the old shape; lists force a diff without metadata + if offset is not None or isinstance(v, list): + continue + + if isinstance(v, weight_adapter.WeightAdapterBase): + adapter_shape = v.calculate_shape(key) + if adapter_shape is not None: + current_shape = adapter_shape + continue + + # Standard diff logic with padding + if len(v) == 2: + patch_type, patch_data = v[0], v[1] + if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']: + current_shape = patch_data[0].shape + + return current_shape + def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): for p in patches: strength = p[0] diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 9d8d21efe..749e81df3 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -5,7 +5,7 @@ import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux sd_out = {} for k in sd: - k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")) + k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight")) sd_out[k_to] = sd[k] sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]) diff --git a/comfy/memory_management.py b/comfy/memory_management.py new file mode 100644 index 000000000..f9078fe7c --- /dev/null +++ b/comfy/memory_management.py @@ -0,0 +1,143 @@ +import math +import ctypes +import threading +import dataclasses +import torch +from typing import NamedTuple + +from comfy.quant_ops import QuantizedTensor + + +class TensorFileSlice(NamedTuple): + file_ref: object + thread_id: int + offset: int + size: int + + +def read_tensor_file_slice_into(tensor, destination): + + if isinstance(tensor, QuantizedTensor): + if not isinstance(destination, QuantizedTensor): + return False + if tensor._layout_cls != destination._layout_cls: + return False + + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + return False + + dst_orig_dtype = destination._params.orig_dtype + destination._params.copy_from(tensor._params, non_blocking=False) + destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + return True + + info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) + if 
info is None: + return False + + file_obj = info.file_ref + if (destination.device.type != "cpu" + or file_obj is None + or threading.get_ident() != info.thread_id + or destination.numel() * destination.element_size() < info.size + or tensor.numel() * tensor.element_size() != info.size + or tensor.storage_offset() != 0 + or not tensor.is_contiguous()): + return False + + if info.size == 0: + return True + + buf_type = ctypes.c_ubyte * info.size + view = memoryview(buf_type.from_address(destination.data_ptr())) + + try: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n + return True + finally: + view.release() + +class TensorGeometry(NamedTuple): + shape: any + dtype: torch.dtype + + def element_size(self): + info = torch.finfo(self.dtype) if self.dtype.is_floating_point else torch.iinfo(self.dtype) + return info.bits // 8 + + def numel(self): + return math.prod(self.shape) + +def tensors_to_geometries(tensors, dtype=None): + geometries = [] + for t in tensors: + if t is None or isinstance(t, QuantizedTensor): + geometries.append(t) + continue + tdtype = t.dtype + if hasattr(t, "_model_dtype"): + tdtype = t._model_dtype + if dtype is not None: + tdtype = dtype + geometries.append(TensorGeometry(shape=t.shape, dtype=tdtype)) + return geometries + +def vram_aligned_size(tensor): + if isinstance(tensor, list): + return sum([vram_aligned_size(t) for t in tensor]) + + if isinstance(tensor, QuantizedTensor): + inner_tensors, _ = tensor.__tensor_flatten__() + return vram_aligned_size([ getattr(tensor, attr) for attr in inner_tensors ]) + + if tensor is None: + return 0 + + size = tensor.numel() * tensor.element_size() + alignment_req = 1024 + return (size + alignment_req - 1) // alignment_req * alignment_req + +def interpret_gathered_like(tensors, gathered): + offset = 0 + dest_views = [] + + if gathered.dim() != 1 or gathered.element_size() != 1: + raise ValueError(f"Buffer must be 1D and single-byte (got {gathered.dim()}D {gathered.dtype})") + + for tensor in tensors: + + if tensor is None: + dest_views.append(None) + continue + + if isinstance(tensor, QuantizedTensor): + inner_tensors, qt_ctx = tensor.__tensor_flatten__() + templates = { attr: getattr(tensor, attr) for attr in inner_tensors } + else: + templates = { "data": tensor } + + actuals = {} + for attr, template in templates.items(): + size = template.numel() * template.element_size() + if offset + size > gathered.numel(): + raise ValueError(f"Buffer too small: needs {offset + size} bytes, but only has {gathered.numel()}.
") + actuals[attr] = gathered[offset:offset+size].view(dtype=template.dtype).view(template.shape) + offset += vram_aligned_size(template) + + if isinstance(tensor, QuantizedTensor): + dest_views.append(QuantizedTensor.__tensor_unflatten__(actuals, qt_ctx, 0, 0)) + else: + dest_views.append(actuals["data"]) + + return dest_views + +aimdo_enabled = False diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..70aff886e 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,8 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model +import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -48,6 +50,8 @@ import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model +import comfy.ldm.anima.model +import comfy.ldm.ace.ace_step15 import comfy.model_management import comfy.patcher_extension @@ -73,6 +77,7 @@ class ModelType(Enum): FLUX = 8 IMG_TO_IMG = 9 FLOW_COSMOS = 10 + IMG_TO_IMG_FLOW = 11 def model_sampling(model_config, model_type): @@ -105,6 +110,8 @@ def model_sampling(model_config, model_type): elif model_type == ModelType.FLOW_COSMOS: c = comfy.model_sampling.COSMOS_RFLOW s = comfy.model_sampling.ModelSamplingCosmosRFlow + elif model_type == ModelType.IMG_TO_IMG_FLOW: + c = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSampling(s, c): pass @@ -144,6 +151,8 @@ class BaseModel(torch.nn.Module): self.diffusion_model.to(memory_format=torch.channels_last) logging.debug("using channels last mode for diffusion model") logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype)) + comfy.model_management.archive_model_dtypes(self.diffusion_model) + self.model_type = model_type self.model_sampling = model_sampling(model_config, model_type) @@ -173,10 +182,7 @@ class BaseModel(torch.nn.Module): xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) context = c_crossattn - dtype = self.get_dtype() - - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() xc = xc.to(dtype) device = xc.device @@ -213,6 +219,13 @@ class BaseModel(torch.nn.Module): def get_dtype(self): return self.diffusion_model.dtype + def get_dtype_inference(self): + dtype = self.get_dtype() + + if self.manual_cast_dtype is not None: + dtype = self.manual_cast_dtype + return dtype + def encode_adm(self, **kwargs): return None @@ -273,6 +286,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. 
+ Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -297,7 +316,7 @@ class BaseModel(torch.nn.Module): return out - def load_model_weights(self, sd, unet_prefix=""): + def load_model_weights(self, sd, unet_prefix="", assign=False): to_load = {} keys = list(sd.keys()) for k in keys: @@ -305,7 +324,7 @@ class BaseModel(torch.nn.Module): to_load[k[len(unet_prefix):]] = sd.pop(k) to_load = self.model_config.process_unet_state_dict(to_load) - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=assign) if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -320,7 +339,7 @@ class BaseModel(torch.nn.Module): def process_latent_out(self, latent): return self.latent_format.process_out(latent) - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): extra_sds = [] if clip_state_dict is not None: extra_sds.append(self.model_config.process_clip_state_dict_for_saving(clip_state_dict)) @@ -328,10 +347,7 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_vae_state_dict_for_saving(vae_state_dict)) if clip_vision_state_dict is not None: extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) - - unet_state_dict = self.diffusion_model.state_dict() unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) - if self.model_type == ModelType.V_PREDICTION: unet_state_dict["v_pred"] = torch.tensor([]) @@ -370,9 +386,7 @@ class BaseModel(torch.nn.Module): input_shapes += shape if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): - dtype = self.get_dtype() - if self.manual_cast_dtype is not None: - dtype = self.manual_cast_dtype + dtype = self.get_dtype_inference() #TODO: this needs to be tweaked area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) @@ -774,8 +788,8 @@ class StableAudio1(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out - def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): - sd = super().state_dict_for_saving(clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) d = {"conditioner.conditioners.seconds_start.": self.seconds_start_embedder.state_dict(), "conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} for k in d: s = d[k] @@ -918,6 +932,26 @@ class Flux(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class LongCatImage(Flux): + def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): + transformer_options = transformer_options.copy() + rope_opts 
= transformer_options.get("rope_options", {}) + rope_opts = dict(rope_opts) + pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0 + rope_opts.setdefault("shift_t", 1.0) + rope_opts.setdefault("shift_y", pe_len) + rope_opts.setdefault("shift_x", pe_len) + transformer_options["rope_options"] = rope_opts + return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) + + def encode_adm(self, **kwargs): + return None + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out.pop('guidance', None) + return out + class Flux2(Flux): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -946,7 +980,7 @@ class GenmoMochi(BaseModel): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -967,6 +1001,10 @@ class LTXV(BaseModel): if keyframe_idxs is not None: out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + return out def process_timestep(self, timestep, x, denoise_mask=None, **kwargs): @@ -977,6 +1015,72 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + device = kwargs["device"] + + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + if hasattr(self.diffusion_model, "preprocess_text_embeds"): + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + guide_attention_entries = kwargs.get("guide_attention_entries", None) + if 
guide_attention_entries is not None: + out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) + + ref_audio = kwargs.get("ref_audio", None) + if ref_audio is not None: + out['ref_audio'] = comfy.conds.CONDConstant(ref_audio) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) @@ -1092,9 +1196,35 @@ class CosmosPredict2(BaseModel): sigma = (sigma / (sigma + 1)) return latent_image / (1.0 - sigma) +class Anima(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.anima.model.Anima) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + t5xxl_ids = kwargs.get("t5xxl_ids", None) + t5xxl_weights = kwargs.get("t5xxl_weights", None) + device = kwargs["device"] + if cross_attn is not None: + if t5xxl_ids is not None: + if t5xxl_weights is not None: + t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn) + t5xxl_ids = t5xxl_ids.unsqueeze(0) + + if torch.is_inference_mode_enabled(): # otherwise we are training + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference())) + else: + out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) + out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) + + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class Lumina2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT) + self.memory_usage_factor_conds = ("ref_latents",) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -1110,12 +1240,46 @@ if 'num_tokens' not in out: out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1]) - clip_text_pooled = kwargs["pooled_output"] # Newbie + clip_text_pooled = kwargs.get("pooled_output", None) # NewBie if clip_text_pooled is not None: out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled) + clip_vision_outputs = kwargs.get("clip_vision_outputs", list(map(lambda a: a.get("clip_vision_output"), kwargs.get("unclip_conditioning", [{}])))) # Z Image omni + if clip_vision_outputs is not None and len(clip_vision_outputs) > 0: + sigfeats = [] + for clip_vision_output in clip_vision_outputs: + if clip_vision_output is not None: + image_size = clip_vision_output.image_sizes[0] + shape =
clip_vision_output.last_hidden_state.shape + sigfeats.append(clip_vision_output.last_hidden_state.reshape(shape[0], image_size[1] // 16, image_size[2] // 16, shape[-1])) + if len(sigfeats) > 0: + out['siglip_feats'] = comfy.conds.CONDList(sigfeats) + + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + latents = [] + for lat in ref_latents: + latents.append(self.process_latent_in(lat)) + out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_contexts = kwargs.get("reference_latents_text_embeds", None) + if ref_contexts is not None: + out['ref_contexts'] = comfy.conds.CONDList(ref_contexts) + return out + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) + return out + +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1223,6 +1387,11 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1275,6 +1444,11 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1292,6 +1466,13 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class 
WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1328,6 +1509,11 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) @@ -1349,6 +1535,50 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class WAN21_FlowRVS(WAN21): + def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): + model_config.unet_config["model_type"] = "t2v" + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) + self.image_to_video = image_to_video + +class WAN21_SCAIL(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAILWanModel) + self.memory_usage_factor_conds = ("reference_latent", "pose_latents") + self.memory_usage_shape_process = {"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]]} + self.image_to_video = image_to_video + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + reference_latents = kwargs.get("reference_latents", None) + if reference_latents is not None: + ref_latent = self.process_latent_in(reference_latents[-1]) + ref_mask = torch.ones_like(ref_latent[:, :4]) + ref_latent = torch.cat([ref_latent, ref_mask], dim=1) + out['reference_latent'] = comfy.conds.CONDRegular(ref_latent) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + pose_latents = self.process_latent_in(pose_latents) + pose_mask = torch.ones_like(pose_latents[:, :4]) + pose_latents = torch.cat([pose_latents, pose_mask], dim=1) + out['pose_latents'] = comfy.conds.CONDRegular(pose_latents) + + return out + + def extra_conds_shapes(self, **kwargs): + out = {} + ref_latents = kwargs.get("reference_latents", None) + if ref_latents is not None: + out['reference_latent'] = list([1, 20, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + + pose_latents = kwargs.get("pose_video_latent", None) + if pose_latents is not None: + out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] + + return out + class Hunyuan3Dv2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2) @@ -1434,6 +1664,49 @@ class ACEStep(BaseModel): out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0)) return out +class ACEStep15(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, 
unet_model=comfy.ldm.ace.ace_step15.AceStepConditionGenerationModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + device = kwargs["device"] + noise = kwargs["noise"] + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + if torch.count_nonzero(cross_attn) == 0: + out['replace_with_null_embeds'] = comfy.conds.CONDConstant(True) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + conditioning_lyrics = kwargs.get("conditioning_lyrics", None) + if conditioning_lyrics is not None: + out['lyric_embed'] = comfy.conds.CONDRegular(conditioning_lyrics) + + refer_audio = kwargs.get("reference_audio_timbre_latents", None) + if refer_audio is None or len(refer_audio) == 0: + refer_audio = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device) + pass_audio_codes = True + else: + refer_audio = refer_audio[-1][:, :, :noise.shape[2]] + out['is_covers'] = comfy.conds.CONDConstant(True) + pass_audio_codes = False + + if pass_audio_codes: + audio_codes = kwargs.get("audio_codes", None) + if audio_codes is not None: + out['audio_codes'] = comfy.conds.CONDRegular(torch.tensor(audio_codes, device=device)) + refer_audio = refer_audio[:, :, :750] + else: + out['is_covers'] = comfy.conds.CONDConstant(False) + + if refer_audio.shape[2] < noise.shape[2]: + pad = comfy.ldm.ace.ace_step15.get_silence_latent(noise.shape[2], device) + refer_audio = torch.cat([refer_audio.to(pad), pad[:, :, refer_audio.shape[2]:]], dim=2) + + out['refer_audio'] = comfy.conds.CONDRegular(refer_audio) + return out + class Omnigen2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.omnigen.omnigen2.OmniGen2Transformer2DModel) @@ -1471,6 +1744,9 @@ class QwenImage(BaseModel): def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 84fd409fd..35a6822e3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -1,4 +1,5 @@ import json +import comfy.memory_management import comfy.supported_models import comfy.supported_models_base import comfy.utils @@ -19,6 +20,12 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def any_suffix_in(keys, prefix, main, suffix_list=[]): + for x in suffix_list: + if "{}{}{}".format(prefix, main, x) in keys: + return True + return False + def calculate_transformer_depth(prefix, state_dict_keys, state_dict): context_dim = None use_linear_in_transformer = False @@ -186,7 +193,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["meanflow_sum"] = False return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) + if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", 
"scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: dit_config["image_model"] = "flux2" @@ -237,9 +244,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["vec_in_dim"] = None + dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"]) + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma + + if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 dit_config["out_channels"] = 64 @@ -247,17 +257,18 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 - if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + + if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 - dit_config["patch_size"] = 16 + dit_config["patch_size"] = state_dict.get('{}img_in_patch.weight'.format(key_prefix)).size(dim=-1) dit_config["nerf_hidden_size"] = 64 dit_config["nerf_mlp_ratio"] = 4 dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 dit_config["nerf_tile_size"] = 512 - dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred dit_config["use_x0"] = True @@ -266,9 +277,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys - dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] + if dit_config.get("context_in_dim") == 3584 and dit_config["vec_in_dim"] is None: # LongCat-Image + dit_config["txt_ids_dims"] = [1, 2] return dit_config @@ -305,7 +318,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') 
shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 @@ -411,7 +424,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -430,8 +443,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["rope_theta"] = 10000.0 dit_config["ffn_dim_multiplier"] = 4.0 ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) - if ctd_weight is not None: + if ctd_weight is not None: # NewBie dit_config["clip_text_dim"] = ctd_weight.shape[0] + # NewBie also sets axes_lens = [1024, 512, 512] but it's not used in ComfyUI elif dit_config["dim"] == 3840: # Z image dit_config["n_heads"] = 30 dit_config["n_kv_heads"] = 30 @@ -441,8 +455,38 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) dit_config["z_image_modulation"] = True dit_config["time_scale"] = 1000.0 + try: + dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42 + except Exception: + pass if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: dit_config["pad_tokens_multiple"] = 32 + sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None) + if sig_weight is not None: + dit_config["siglip_feat_dim"] = sig_weight.shape[0] + + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' 
+ ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True return dit_config @@ -478,6 +522,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["model_type"] = "humo" elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "animate" + elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys: + dit_config["model_type"] = "scail" else: if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys: dit_config["model_type"] = "i2v" @@ -491,6 +537,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if ref_conv_weight is not None: dit_config["in_dim_ref_conv"] = ref_conv_weight.shape[1] + if metadata is not None and "config" in metadata: + dit_config.update(json.loads(metadata["config"]).get("transformer", {})) + return dit_config if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D @@ -508,8 +557,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -544,6 +592,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}blocks.0.mlp.layer1.weight'.format(key_prefix) in state_dict_keys: # Cosmos predict2 dit_config = {} dit_config["image_model"] = "cosmos_predict2" + if "{}llm_adapter.blocks.0.cross_attn.q_proj.weight".format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "anima" dit_config["max_img_h"] = 240 dit_config["max_img_w"] = 240 dit_config["max_frames"] = 128 @@ -643,6 +693,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys: + dit_config = {} + dit_config["audio_model"] = "ace1.5" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -767,6 +822,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["use_temporal_resblock"] = False unet_config["use_temporal_attention"] = False + heatmap_key = '{}heatmap_head.conv_layers.0.weight'.format(key_prefix) + if heatmap_key in state_dict_keys: + unet_config["heatmap_head"] = True + return unet_config def model_config_from_unet_config(unet_config, state_dict=None): @@ -987,7 +1046,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], - 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8, + 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 
'use_linear_in_transformer': True, 'context_dim': 1024, 'num_head_channels': 64, 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], 'use_temporal_attention': False, 'use_temporal_resblock': False} @@ -1019,6 +1078,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') @@ -1053,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): new[:old_weight.shape[0]] = old_weight old_weight = new + if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled: + old_weight = old_weight.clone() + w = old_weight.narrow(offset[0], offset[1], offset[2]) else: + if comfy.memory_management.aimdo_enabled: + weight = weight.clone() old_weight = weight w = weight w[:] = fun(weight) diff --git a/comfy/model_management.py b/comfy/model_management.py index 40717b1e4..9617d8388 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -20,12 +20,17 @@ import psutil import logging from enum import Enum from comfy.cli_args import args, PerformanceFeature +import threading import torch import sys -import importlib import platform import weakref import gc +import os +from contextlib import nullcontext +import comfy.memory_management +import comfy.utils +import comfy.quant_ops class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -47,6 +52,12 @@ cpu_state = CPUState.GPU total_vram = 0 + +# Training Related State +in_training = False +training_fp8_bwd = False + + def get_supported_float8_types(): float8_types = [] try: @@ -167,6 +178,14 @@ def is_ixuca(): return True return False +def is_wsl(): + version = platform.uname().release + if version.endswith("-Microsoft"): + return True + elif version.endswith("microsoft-standard-WSL2"): + return True + return False + def get_torch_device(): global directml_enabled global cpu_state @@ -252,6 +271,23 @@ try: except: OOM_EXCEPTION = Exception +try: + ACCELERATOR_ERROR = torch.AcceleratorError +except AttributeError: + ACCELERATOR_ERROR = RuntimeError + +def is_oom(e): + if isinstance(e, OOM_EXCEPTION): + return True + if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()): + discard_cuda_async_error() + return True + return False + +def raise_non_oom(e): + if not is_oom(e): + raise e + XFORMERS_VERSION = "" XFORMERS_ENABLED_VAE = True if args.disable_xformers: @@ -333,28 +369,42 @@ except: SUPPORT_FP8_OPS = args.supports_fp8_compute AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] +AMD_ENABLE_MIOPEN_ENV = 'COMFYUI_ENABLE_MIOPEN' try: if is_amd(): - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + arch 
= torch.cuda.get_device_properties(get_torch_device()).gcnArchName.split(':')[0] if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + if os.getenv(AMD_ENABLE_MIOPEN_ENV) != '1': + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: rocm_version = (6, -1) + def aotriton_supported(gpu_arch): + path = torch.__path__[0] + path = os.path.join(os.path.join(path, "lib"), "aotriton.images") + gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path)))) + if gpu_arch in gfx: + return True + if "{}x".format(gpu_arch[:-1]) in gfx: + return True + if "{}xx".format(gpu_arch[:-2]) in gfx: + return True + return False + logging.info("AMD arch: {}".format(arch)) logging.info("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: - if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not. + if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): - if any((a in arch) for a in ["gfx1201"]): + if any((a in arch) for a in ["gfx1200", "gfx1201"]): ENABLE_PYTORCH_ATTENTION = True if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4): if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0 @@ -453,9 +503,31 @@ def module_size(module): sd = module.state_dict() for k in sd: t = sd[k] - module_mem += t.nelement() * t.element_size() + module_mem += t.nbytes return module_mem +def module_mmap_residency(module, free=False): + mmap_touched_mem = 0 + module_mem = 0 + bounced_mmaps = set() + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nbytes + storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() + if not getattr(storage, "_comfy_tensor_mmap_touched", False): + continue + mmap_touched_mem += t.nbytes + if not free: + continue + storage._comfy_tensor_mmap_touched = False + mmap_obj = storage._comfy_tensor_mmap_refs[0] + if mmap_obj in bounced_mmaps: + continue + mmap_obj.bounce() + bounced_mmaps.add(mmap_obj) + return mmap_touched_mem, module_mem + class LoadedModel: def __init__(self, model): self._set_model(model) @@ -470,6 +542,7 @@ class LoadedModel: if model.parent is not None: self._parent_model = weakref.ref(model.parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent) + self._patcher_finalizer.atexit = False def _switch_parent(self): model = self._parent_model() @@ -483,6 +556,9 @@ class LoadedModel: def model_memory(self): return self.model.model_size() + def model_mmap_residency(self, free=False): +
return self.model.model_mmap_residency(free=free) + def model_loaded_memory(self): return self.model.loaded_size() @@ -513,6 +589,7 @@ class LoadedModel: self.real_model = weakref.ref(real_model) self.model_finalizer = weakref.finalize(real_model, cleanup_models) + self.model_finalizer.atexit = False return real_model def should_reload_model(self, force_patch_weights=False): @@ -564,9 +641,15 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: + import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 + def get_free_ram(): + return comfy.windows.get_free_ram() +else: + def get_free_ram(): + return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -578,7 +661,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[]): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -591,17 +674,34 @@ def free_memory(memory_required, device, keep_loaded=[]): can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False - for x in sorted(can_unload): + can_unload_sorted = sorted(can_unload) + for x in can_unload_sorted: i = x[-1] - memory_to_free = None + memory_to_free = 1e32 + pins_to_free = 1e32 if not DISABLE_SMART_MEMORY: - free_mem = get_free_memory(device) - if free_mem > memory_required: - break - memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") - if current_loaded_models[i].model_unload(memory_to_free): + memory_to_free = memory_required - get_free_memory(device) + pins_to_free = pins_required - get_free_ram() + if current_loaded_models[i].model.is_dynamic() and for_dynamic: + #don't actually unload dynamic models for the sake of other dynamic models + #as that works on-demand. 
+ memory_required -= current_loaded_models[i].model.loaded_size() + memory_to_free = 0 + if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) + if pins_to_free > 0: + logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(pins_to_free) + + for x in can_unload_sorted: + i = x[-1] + ram_to_free = ram_required - psutil.virtual_memory().available + if ram_to_free <= 0 and i not in unloaded_model: + continue + resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) + if resident_memory > 0: + logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -636,7 +736,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] + free_for_dynamic=True for x in models: + if not x.is_dynamic(): + free_for_dynamic = False loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) @@ -662,19 +765,35 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() + total_memory_required = {} + total_pins_required = {} + total_ram_required = {} for loaded_model in models_to_load: - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) + device = loaded_model.device + total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) + resident_memory, model_memory = loaded_model.model_mmap_residency() + pinned_memory = loaded_model.model.pinned_memory_size() + #FIXME: This can over-free the pins as it budgets to pin the entire model. We should + #make this JIT to keep as much pinned as possible. + pins_required = model_memory - pinned_memory + ram_required = model_memory - resident_memory + total_pins_required[device] = total_pins_required.get(device, 0) + pins_required + total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device) + free_memory(total_memory_required[device] * 1.1 + extra_mem, + device, + for_dynamic=free_for_dynamic, + pins_required=total_pins_required[device], + ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): free_mem = get_free_memory(device) if free_mem < minimum_memory_required: - models_l = free_memory(minimum_memory_required, device) + models_l = free_memory(minimum_memory_required, device, for_dynamic=free_for_dynamic) logging.info("{} models unloaded.".format(len(models_l))) for loaded_model in models_to_load: @@ -718,6 +837,9 @@ def loaded_models(only_currently_used=False): def cleanup_models_gc(): do_gc = False + + reset_cast_buffers() + for i in range(len(current_loaded_models)): cur = current_loaded_models[i] if cur.is_dead(): @@ -735,6 +857,13 @@ def cleanup_models_gc(): logging.warning("WARNING, memory leak with model {}. 
Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) +def archive_model_dtypes(model): + for name, module in model.named_modules(): + for param_name, param in module.named_parameters(recurse=False): + setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) + def cleanup_models(): to_delete = [] @@ -766,11 +895,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -872,7 +1004,7 @@ def text_encoder_offload_device(): def text_encoder_device(): if args.gpu_only: return get_torch_device() - elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM: + elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled: if should_use_fp16(prioritize_performance=False): return get_torch_device() else: @@ -881,6 +1013,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device @@ -918,6 +1053,12 @@ def intermediate_device(): else: return torch.device("cpu") +def intermediate_dtype(): + if args.fp16_intermediates: + return torch.float16 + else: + return torch.float32 + def vae_device(): if args.cpu_vae: return torch.device("cpu") @@ -1016,8 +1157,8 @@ NUM_STREAMS = 0 if args.async_offload is not None: NUM_STREAMS = args.async_offload else: - # Enable by default on Nvidia - if is_nvidia(): + # Enable by default on Nvidia and AMD + if is_nvidia() or is_amd(): NUM_STREAMS = 2 if args.disable_async_offload: @@ -1037,6 +1178,51 @@ def current_stream(device): return None stream_counters = {} + +STREAM_CAST_BUFFERS = {} +LARGEST_CASTED_WEIGHT = (None, 0) + +def get_cast_buffer(offload_stream, device, size, ref): + global LARGEST_CASTED_WEIGHT + + if offload_stream is not None: + wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) + else: + wf_context = nullcontext() + + cast_buffer = STREAM_CAST_BUFFERS.get(offload_stream, None) + if cast_buffer is None or cast_buffer.numel() < size: + if ref is LARGEST_CASTED_WEIGHT[0]: + #If there is one giant weight we do not want both streams to + #allocate a buffer for it. 
It's up to the caster to get the other + #offload stream in this corner case + return None + if cast_buffer is not None and cast_buffer.numel() > 50 * (1024 ** 2): + #I want my wrongly sized 50MB+ of VRAM back from the caching allocator right now + synchronize() + del STREAM_CAST_BUFFERS[offload_stream] + del cast_buffer + soft_empty_cache() + with wf_context: + cast_buffer = torch.empty((size), dtype=torch.int8, device=device) + STREAM_CAST_BUFFERS[offload_stream] = cast_buffer + + if size > LARGEST_CASTED_WEIGHT[1]: + LARGEST_CASTED_WEIGHT = (ref, size) + + return cast_buffer + +def reset_cast_buffers(): + global LARGEST_CASTED_WEIGHT + LARGEST_CASTED_WEIGHT = (None, 0) + for offload_stream in STREAM_CAST_BUFFERS: + offload_stream.synchronize() + synchronize() + STREAM_CAST_BUFFERS.clear() + soft_empty_cache() + def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) if NUM_STREAMS == 0: @@ -1079,7 +1265,29 @@ def sync_stream(device, stream): return current_stream(device).wait_stream(stream) -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): + +def cast_to_gathered(tensors, r, non_blocking=False, stream=None): + wf_context = nullcontext() + if stream is not None: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + with wf_context: + for tensor in tensors: + dest_view = dest_views.pop(0) + if tensor is None: + continue + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + continue + storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() + if hasattr(storage, "_comfy_tensor_mmap_touched"): + storage._comfy_tensor_mmap_touched = True + dest_view.copy_(tensor, non_blocking=non_blocking) + + +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1098,10 +1306,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) with wf_context: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: - r = torch.empty_like(weight, dtype=dtype, device=device) + if r is None: + r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) return r @@ -1121,7 +1331,17 @@ if not args.disable_pinned_memory: MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) -PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) +PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) + +def discard_cuda_async_error(): + try: + a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) + _ = a + b + synchronize() + except RuntimeError: + #Dump it! 
We already know about it from the synchronous return + pass def pin_memory(tensor): global TOTAL_PINNED_MEMORY @@ -1143,7 +1363,7 @@ def pin_memory(tensor): if not tensor.is_contiguous(): return False - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: return False @@ -1155,6 +1375,9 @@ def pin_memory(tensor): PINNED_MEMORY[ptr] = size TOTAL_PINNED_MEMORY += size return True + else: + logging.warning("Pin error.") + discard_cuda_async_error() return False @@ -1167,7 +1390,7 @@ def unpin_memory(tensor): return False ptr = tensor.data_ptr() - size = tensor.numel() * tensor.element_size() + size = tensor.nbytes size_stored = PINNED_MEMORY.get(ptr, None) if size_stored is None: @@ -1183,6 +1406,9 @@ def unpin_memory(tensor): if len(PINNED_MEMORY) == 0: TOTAL_PINNED_MEMORY = 0 return True + else: + logging.warning("Unpin error.") + discard_cuda_async_error() return False @@ -1485,6 +1711,29 @@ def supports_fp8_compute(device=None): return True +def supports_nvfp4_compute(device=None): + if not is_nvidia(): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + if torch_version_numeric < (2, 10): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): @@ -1506,7 +1755,17 @@ def lora_compute_dtype(device): LORA_COMPUTE_DTYPES[device] = dtype return dtype +def synchronize(): + if cpu_mode(): + return + if is_intel_xpu(): + torch.xpu.synchronize() + elif torch.cuda.is_available(): + torch.cuda.synchronize() + def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() @@ -1517,15 +1776,17 @@ def soft_empty_cache(force=False): elif is_mlu(): torch.mlu.empty_cache() elif torch.cuda.is_available(): + torch.cuda.synchronize() torch.cuda.empty_cache() torch.cuda.ipc_collect() def unload_all_models(): free_memory(1e30, get_torch_device()) - -#TODO: might be cleaner to put this somewhere else -import threading +def debug_memory_summary(): + if is_amd() or is_nvidia(): + return torch.cuda.memory.memory_summary() + return "" class InterruptProcessingException(Exception): pass diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d26c690..c26d37db2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -19,7 +19,6 @@ from __future__ import annotations import collections -import copy import inspect import logging import math @@ -38,19 +37,7 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP - -def string_to_seed(data): - crc = 0xFFFFFFFF - for byte in data: - if isinstance(byte, str): - byte = ord(byte) - crc ^= byte - for _ in range(8): - if crc & 1: - crc = (crc >> 1) ^ 0xEDB88320 - else: - crc >>= 1 - return crc ^ 0xFFFFFFFF +import comfy_aimdo.model_vbar def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None): to = model_options["transformer_options"].copy() @@ -123,6 +110,10 @@ def move_weight_functions(m, device): memory += f.move_to(device=device) return memory +def string_to_seed(data): + 
logging.warning("WARNING: string_to_seed has moved from comfy.model_patcher to comfy.utils") + return comfy.utils.string_to_seed(data) + class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key @@ -169,6 +160,11 @@ def get_key_weight(model, key): return weight, set_func, convert_func +def key_param_name_to_key(key, param): + if len(key) == 0: + return param + return "{}.{}".format(key, param) + class AutoPatcherEjector: def __init__(self, model: 'ModelPatcher', skip_and_inject_on_exit_only=False): self.model = model @@ -212,6 +208,27 @@ class MemoryCounter: def decrement(self, used: int): self.value -= used +CustomTorchDevice = collections.namedtuple("FakeDevice", ["type", "index"])("comfy-lazy-caster", 0) + +class LazyCastingParam(torch.nn.Parameter): + def __new__(cls, model, key, tensor): + return super().__new__(cls, tensor) + + def __init__(self, model, key, tensor): + self.model = model + self.key = key + + @property + def device(self): + return CustomTorchDevice + + #safetensors will .to() us to the cpu which we catch here to cast on demand. The returned tensor is + #then just a short lived thing in the safetensors serialization logic inside its big for loop over + #all weights getting garbage collected per-weight + def to(self, *args, **kwargs): + return self.model.patch_weight_to_device(self.key, device_to=self.model.load_device, return_weight=True).to("cpu") + + class ModelPatcher: def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): self.size = size @@ -224,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -254,6 +272,7 @@ class ModelPatcher: self.is_clip = False self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed + self.cached_patcher_init: tuple[Callable, tuple] | None = None if not hasattr(self.model, 'model_loaded_weight_memory'): self.model.model_loaded_weight_memory = 0 @@ -269,12 +288,18 @@ class ModelPatcher: if not hasattr(self.model, 'model_offload_buffer_memory'): self.model.model_offload_buffer_memory = 0 + def is_dynamic(self): + return False + def model_size(self): if self.size > 0: return self.size self.size = comfy.model_management.module_size(self.model) return self.size + def model_mmap_residency(self, free=False): + return comfy.model_management.module_mmap_residency(self.model, free=free) + def get_ram_usage(self): return self.model_size() @@ -284,8 +309,31 @@ class ModelPatcher: def lowvram_patch_counter(self): return self.model.lowvram_patch_counter - def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) + def get_free_memory(self, device): + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. 
So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem + + def get_clone_model_override(self): + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) + + def clone(self, disable_dynamic=False, model_override=None): + class_ = self.__class__ + if self.is_dynamic() and disable_dynamic: + class_ = ModelPatcher + if model_override is None: + if self.cached_patcher_init is None: + raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.") + temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True) + model_override = temp_model_patcher.get_clone_model_override() + if model_override is None: + model_override = self.get_clone_model_override() + + n = class_(model_override[0], self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n.patches = {} for k in self.patches: n.patches[k] = self.patches[k][:] @@ -293,14 +341,13 @@ class ModelPatcher: n.object_patches = self.object_patches.copy() n.weight_wrapper_patches = self.weight_wrapper_patches.copy() - n.model_options = copy.deepcopy(self.model_options) - n.backup = self.backup - n.object_patches_backup = self.object_patches_backup + n.model_options = comfy.utils.deepcopy_list_dict(self.model_options) n.parent = self - n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] + # attachments n.attachments = {} for k in self.attachments: @@ -339,6 +386,8 @@ class ModelPatcher: n.is_clip = self.is_clip n.hook_mode = self.hook_mode + n.cached_patcher_init = self.cached_patcher_init + for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE): callback(self, n) return n @@ -383,13 +432,16 @@ class ModelPatcher: def memory_required(self, input_shape): return self.model.memory_required(input_shape=input_shape) + def disable_model_cfg1_optimization(self): + self.model_options["disable_cfg1_optimization"] = True + def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way else: self.model_options["sampler_cfg_function"] = sampler_cfg_function if disable_cfg1_optimization: - self.model_options["disable_cfg1_optimization"] = True + self.disable_model_cfg1_optimization() def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) @@ -550,6 +602,27 @@ class ModelPatcher: return models + def model_patches_call_function(self, function_name="cleanup", arguments={}): + to = self.model_options["transformer_options"] + if "patches" in to: + patches = to["patches"] + for name in patches: + patch_list = patches[name] + for i in range(len(patch_list)): + if hasattr(patch_list[i], function_name): + getattr(patch_list[i], function_name)(**arguments) + if "patches_replace" in to: + patches = to["patches_replace"] + for name in patches: + patch_list = patches[name] + for k in patch_list: + if hasattr(patch_list[k], function_name): 
+ getattr(patch_list[k], function_name)(**arguments) + if "model_function_wrapper" in self.model_options: + wrap_func = self.model_options["model_function_wrapper"] + if hasattr(wrap_func, function_name): + getattr(wrap_func, function_name)(**arguments) + def model_dtype(self): if hasattr(self.model, "get_dtype"): return self.model.get_dtype() @@ -611,14 +684,14 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False): - if key not in self.patches: - return - + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): weight, set_func, convert_func = get_key_weight(self.model, key) + if key not in self.patches: + return weight + inplace_update = self.weight_inplace_update or inplace_update - if key not in self.backup: + if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) @@ -631,13 +704,15 @@ class ModelPatcher: out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) - if inplace_update: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if return_weight: + return out_weight + elif inplace_update: comfy.utils.copy_to_param(self.model, key, out_weight) else: comfy.utils.set_attr_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + return set_func(out_weight, inplace_update=inplace_update, seed=comfy.utils.string_to_seed(key), return_weight=return_weight) def pin_weight_to_device(self, key): weight, set_func, convert_func = get_key_weight(self.model, key) @@ -654,18 +729,19 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) + default = False + params = { name: param for name, param in m.named_parameters(recurse=False) } for name, param in m.named_parameters(recurse=True): if name not in params: - skip = True # skip random weights in non leaf modules + default = True # default random weights in non leaf modules break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + if default and default_device is not None: + for param_name, param in params.items(): + param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None)) + if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): @@ -681,7 +757,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - loading.append((module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. 
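[Reviewer note: a small runnable sketch of the two orderings described in the comment above; the module names and offload sizes are made up.]

# Made-up (offload_mem, name) pairs standing in for entries of the loading list.
mods = [(32 * 1024, "small"), (4 * 1024 ** 2, "big"), (1 * 1024 ** 2, "mid")]

dynamic = sorted(mods, key=lambda m: (m[0] >= 64 * 1024, -m[0]))
assert [n for _, n in dynamic] == ["small", "big", "mid"]  # sub-64KB first, then largest offload cost first

non_dynamic = sorted(mods, key=lambda m: (m[0],))
assert [n for _, n in non_dynamic] == ["small", "mid", "big"]  # ascending offload cost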
+ if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -718,6 +800,7 @@ class ModelPatcher: continue cast_weight = self.force_cast_weights + m.comfy_force_cast_weights = self.force_cast_weights if lowvram_weight: if hasattr(m, "comfy_cast_weights"): m.weight_function = [] @@ -772,7 +855,7 @@ class ModelPatcher: continue for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) self.unpin_weight(key) self.patch_weight_to_device(key, device_to=device_to) if comfy.model_management.is_device_cuda(device_to): @@ -788,13 +871,14 @@ class ModelPatcher: n = x[1] params = x[3] for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) + usable_stat = "{:.2f} MB usable,".format(lowvram_model_memory / (1024 * 1024)) if lowvram_model_memory < 1e32 else "" if lowvram_counter > 0: - logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) + logging.info("loaded partially; {} {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(usable_stat, mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter)) self.model.model_lowvram = True else: - logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load)) + logging.info("loaded completely; {} {:.2f} MB loaded, full load: {}".format(usable_stat, mem_counter / (1024 * 1024), full_load)) self.model.model_lowvram = False if full_load: self.model.to(device_to) @@ -893,7 +977,7 @@ class ModelPatcher: if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: move_weight = True for param in params: - key = "{}.{}".format(n, param) + key = key_param_name_to_key(n, param) bk = self.backup.get(key, None) if bk is not None: if not lowvram_possible: @@ -944,7 +1028,7 @@ class ModelPatcher: logging.debug("freed {}".format(n)) for param in params: - self.pin_weight_to_device("{}.{}".format(n, param)) + self.pin_weight_to_device(key_param_name_to_key(n, param)) self.model.model_lowvram = True @@ -982,6 +1066,13 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def pinned_memory_size(self): + # Pinned memory pressure tracking is only implemented for DynamicVram loading + return 0 + + def partially_unload_ram(self, ram_to_unload): + pass + def detach(self, unpatch_all=True): self.eject_model() self.model_patches_to(self.offload_device) @@ -999,6 +1090,7 @@ class ModelPatcher: return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): + self.model_patches_call_function(function_name="cleanup") self.clean_hooks() if hasattr(self.model, "current_patcher"): self.model.current_patcher = None @@ -1315,10 +1407,10 @@ class ModelPatcher: key, original_weights=original_weights) del original_weights[key] if set_func is None: - out_weight = 
comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=string_to_seed(key)) + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) comfy.utils.copy_to_param(self.model, key, out_weight) else: - set_func(out_weight, inplace_update=True, seed=string_to_seed(key)) + set_func(out_weight, inplace_update=True, seed=comfy.utils.string_to_seed(key)) if self.hook_mode == comfy.hooks.EnumHookMode.MaxSpeed: # TODO: disable caching if not enough system RAM to do so target_device = self.offload_device @@ -1353,7 +1445,291 @@ class ModelPatcher: self.unpatch_hooks() self.clear_cached_hook_weights() + def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + unet_state_dict = self.model.diffusion_model.state_dict() + for k, v in unet_state_dict.items(): + op_keys = k.rsplit('.', 1) + if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]: + continue + try: + op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0]) + except: + continue + if not op or not hasattr(op, "comfy_cast_weights") or \ + (hasattr(op, "comfy_patched_weights") and op.comfy_patched_weights == True): + continue + key = "diffusion_model." + k + unet_state_dict[k] = LazyCastingParam(self, key, comfy.utils.get_attr(self.model, key)) + return self.model.state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + def __del__(self): self.unpin_all_weights() self.detach(unpatch_all=False) +class ModelPatcherDynamic(ModelPatcher): + + def __new__(cls, model=None, load_device=None, offload_device=None, size=0, weight_inplace_update=False): + if load_device is not None and comfy.model_management.is_device_cpu(load_device): + #reroute to default MP for CPUs + return ModelPatcher(model, load_device, offload_device, size, weight_inplace_update) + return super().__new__(cls) + + def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): + super().__init__(model, load_device, offload_device, size, weight_inplace_update) + if not hasattr(self.model, "dynamic_vbars"): + self.model.dynamic_vbars = {} + self.non_dynamic_delegate_model = None + assert load_device is not None + + def is_dynamic(self): + return True + + def _vbar_get(self, create=False): + if self.load_device == torch.device("cpu"): + return None + vbar = self.model.dynamic_vbars.get(self.load_device, None) + if create and vbar is None: + # x10. We don't know what model-defined type casts we have in the vbar, but virtual address + # space is pretty free. This will cover someone casting an entire model from FP4 to FP32 + # with some left over. + vbar = comfy_aimdo.model_vbar.ModelVBAR(self.model_size() * 10, self.load_device.index) + self.model.dynamic_vbars[self.load_device] = vbar + return vbar + + def loaded_size(self): + vbar = self._vbar_get() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory + + #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. + + def pin_weight_to_device(self, key): + raise RuntimeError("pin_weight_to_device invalid for dynamic weight loading") + + def unpin_weight(self, key): + raise RuntimeError("unpin_weight invalid for dynamic weight loading") + + def unpin_all_weights(self): + self.partially_unload_ram(1e32) +
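[Reviewer note: a brief illustrative sketch of the __new__ rerouting above; model construction is elided and `model` is assumed to exist.]

import torch

# Hypothetical: with a CPU load device, ModelPatcherDynamic.__new__ hands back a
# plain ModelPatcher, since dynamic VRAM staging only applies to accelerator devices.
# Python also skips ModelPatcherDynamic.__init__ here, because the returned object
# is not an instance of the subclass.
patcher = ModelPatcherDynamic(model, torch.device("cpu"), torch.device("cpu"))
assert type(patcher) is ModelPatcher
assert not patcher.is_dynamic()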
+        #estimate is only used when using the ModelPatcherDynamic after ModelPatcher. If you
+        #use all ModelPatcherDynamic this is ignored and it's all done dynamically.
+        return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3)
+
+
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False):
+
+        #Force patching doesn't make sense in Dynamic loading, as you don't know what does and
+        #doesn't need to be forced at this stage. The only thing you could do would be patch
+        #it all on CPU which consumes huge RAM.
+        assert not force_patch_weights
+
+        #Full load doesn't make sense as we don't actually have any loader capability here and
+        #now.
+        assert not full_load
+
+        assert device_to == self.load_device
+
+        num_patches = 0
+        allocated_size = 0
+        self.model.model_loaded_weight_memory = 0
+
+        with self.use_ejected():
+            self.unpatch_hooks()
+
+            vbar = self._vbar_get(create=True)
+            if vbar is not None:
+                vbar.prioritize()
+
+            loading = self._load_list(for_dynamic=True, default_device=device_to)
+            loading.sort()
+
+            for x in loading:
+                *_, module_mem, n, m, params = x
+
+                def set_dirty(item, dirty):
+                    if dirty or not hasattr(item, "_v_signature"):
+                        item._v_signature = None
+
+                def setup_param(self, m, n, param_key):
+                    nonlocal num_patches
+                    key = key_param_name_to_key(n, param_key)
+
+                    weight_function = []
+
+                    weight, _, _ = get_key_weight(self.model, key)
+                    if weight is None:
+                        return (False, 0)
+                    if key in self.patches:
+                        if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+                            return (True, 0)
+                        setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
+                        num_patches += 1
+                    else:
+                        setattr(m, param_key + "_lowvram_function", None)
+
+                    if key in self.weight_wrapper_patches:
+                        weight_function.extend(self.weight_wrapper_patches[key])
+                    setattr(m, param_key + "_function", weight_function)
+                    geometry = weight
+                    if not isinstance(weight, QuantizedTensor):
+                        model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
+                        weight._model_dtype = model_dtype
+                        geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
+                    return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+                def force_load_param(self, param_key, device_to):
+                    key = key_param_name_to_key(n, param_key)
+                    if key in self.backup:
+                        comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+                    self.patch_weight_to_device(key, device_to=device_to)
+                    weight, _, _ = get_key_weight(self.model, key)
+                    if weight is not None:
+                        self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
+
+                if hasattr(m, "comfy_cast_weights"):
+                    m.comfy_cast_weights = True
+                    m.pin_failed = False
+                    m.seed_key = n
+                    set_dirty(m, dirty)
+
+                    force_load, v_weight_size = setup_param(self, m, n, "weight")
+                    force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+                    force_load = force_load or force_load_bias
+                    v_weight_size += v_weight_bias
+
+                    if force_load:
+                        logging.info(f"Module {n} has resizing Lora - force loading")
+                        force_load_param(self, "weight", device_to)
+                        force_load_param(self, "bias", device_to)
+                    else:
+                        if vbar is not None and not hasattr(m, "_v"):
+                            m._v = vbar.alloc(v_weight_size)
+                            allocated_size += v_weight_size
+
+                else:
+                    for param in params:
+                        key = key_param_name_to_key(n, param)
+                        weight, _, _ = get_key_weight(self.model, key)
+                        if key not in self.backup:
+                            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
+                        model_dtype = getattr(m, param + "_comfy_model_dtype", None)
+                        casted_weight = weight.to(dtype=model_dtype, device=device_to)
+                        comfy.utils.set_attr_param(self.model, key, casted_weight)
+                        self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
+
+                move_weight_functions(m, device_to)
+
+            for key, buf in self.model.named_buffers(recurse=True):
+                if key not in self.backup_buffers:
+                    self.backup_buffers[key] = buf
+                module, buf_name = comfy.utils.resolve_attr(self.model, key)
+                model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
+                casted_buf = buf.to(dtype=model_dtype, device=device_to)
+                comfy.utils.set_attr_buffer(self.model, key, casted_buf)
+                self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
+
+            force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
+            logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB staged. {num_patches} patches attached.{force_load_stat}")
+
+            self.model.device = device_to
+            self.model.current_weight_patches_uuid = self.patches_uuid
+
+            for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
+                #These are all super dangerous. Who knows what the custom nodes actually do here...
+                callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
+
+            self.apply_hooks(self.forced_hooks, force_apply=True)
+
+    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
+        assert not force_patch_weights #See above
+        assert self.load_device != torch.device("cpu")
+
+        vbar = self._vbar_get()
+        freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
+
+        if freed < memory_to_free:
+            for key in list(self.backup.keys()):
+                bk = self.backup.pop(key)
+                comfy.utils.set_attr_param(self.model, key, bk.weight)
+            for key in list(self.backup_buffers.keys()):
+                comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
+            freed += self.model.model_loaded_weight_memory
+            self.model.model_loaded_weight_memory = 0
+
+        return freed
+
+    def pinned_memory_size(self):
+        total = 0
+        loading = self._load_list(for_dynamic=True)
+        for x in loading:
+            _, _, _, _, m, _ = x
+            pin = comfy.pinned_memory.get_pin(m)
+            if pin is not None:
+                total += pin.numel() * pin.element_size()
+        return total
+
+    def partially_unload_ram(self, ram_to_unload):
+        loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
+        for x in loading:
+            *_, m, _ = x
+            ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
+            if ram_to_unload <= 0:
+                return
+
+    def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
+        #This isn't used by the core at all and can only be used to load a model out of
+        #the control of proper model_management. If you are a custom node author reading
+        #this, the correct pattern is to call load_models_gpu() to get a proper
+        #managed load of your model.
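+        #
+        #A minimal sketch of that pattern (assumes `patcher` is the ModelPatcher
+        #instance your node built; load_models_gpu is the managed core loader):
+        #
+        #    comfy.model_management.load_models_gpu([patcher])
+        #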
+        assert not load_weights
+        return super().patch_model(load_weights=load_weights, force_patch_weights=force_patch_weights)
+
+    def unpatch_model(self, device_to=None, unpatch_weights=True):
+        super().unpatch_model(device_to=None, unpatch_weights=False)
+
+        if unpatch_weights:
+            self.partially_unload_ram(1e32)
+            self.partially_unload(None, 1e32)
+            for m in self.model.modules():
+                move_weight_functions(m, device_to)
+
+    def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
+        assert not force_patch_weights #See above
+        with self.use_ejected(skip_and_inject_on_exit_only=True):
+            dirty = self.model.current_weight_patches_uuid is not None and (self.model.current_weight_patches_uuid != self.patches_uuid)
+
+            self.unpatch_model(self.offload_device, unpatch_weights=False)
+            self.patch_model(load_weights=False)
+
+            try:
+                self.load(device_to, dirty=dirty)
+            except Exception as e:
+                self.detach()
+                raise e
+        #ModelPatcher::partially_load returns a number describing what got loaded, but
+        #nothing in core uses this and we have no data in the Dynamic world. Hit
+        #the custom node devs with a None rather than a 0 that would mislead any
+        #logic they might have.
+        return None
+
+    def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
+        assert False #Should be unreachable - we don't ever cache in the new implementation
+
+    def patch_hook_weight_to_device(self, hooks: comfy.hooks.HookGroup, combined_patches: dict, key: str, original_weights: dict, memory_counter: MemoryCounter):
+        if key not in combined_patches:
+            return
+
+        raise RuntimeError("Hooks not implemented in ModelPatcherDynamic. Please remove --fast arguments from ComfyUI startup")
+
+    def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
+        pass
+
+    def get_non_dynamic_delegate(self):
+        model_patcher = self.clone(disable_dynamic=True, model_override=self.non_dynamic_delegate_model)
+        self.non_dynamic_delegate_model = model_patcher.get_clone_model_override()
+        return model_patcher
+
+
+CoreModelPatcher = ModelPatcher
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 2a00ed819..13860e6a2 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -83,6 +83,16 @@ class IMG_TO_IMG(X0):
     def calculate_input(self, sigma, noise):
         return noise
 
+class IMG_TO_IMG_FLOW(CONST):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        return model_output
+
+    def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+        return latent_image
+
+    def inverse_noise_scaling(self, sigma, latent):
+        return 1.0 - latent
+
 class COSMOS_RFLOW:
     def calculate_input(self, sigma, noise):
         sigma = (sigma / (sigma + 1))
diff --git a/comfy/ops.py b/comfy/ops.py
index 16889bb82..ca25693db 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,8 +21,13 @@ import logging
 import comfy.model_management
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
-import comfy.rmsnorm
 import json
+import comfy.memory_management
+import comfy.pinned_memory
+import comfy.utils
+
+import comfy_aimdo.model_vbar
+import comfy_aimdo.torch
 
 def run_every_op():
     if torch.compiler.is_compiling():
@@ -48,6 +53,8 @@ try:
     SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION)
 
     def scaled_dot_product_attention(q, k, v, *args, **kwargs):
+        if q.nelement() < 1024 * 128: # arbitrary number, for small inputs cudnn attention seems slower
+            return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
         with sdpa_kernel(SDPA_BACKEND_PRIORITY,
set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: @@ -72,14 +79,142 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): +def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + + #vbar doesn't support CPU weights, but some custom nodes have weird paths + #that might switch the layer to the CPU and expect it to work. We have to take + #a clone conservatively as we are mmapped and some SFT files are packed misaligned + #If you are a custom node author reading this, please move your layer to the GPU + #or declare your ModelPatcher as CPU in the first place. + if comfy.model_management.is_device_cpu(device): + weight = s.weight.to(dtype=dtype, copy=True) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + bias = None + if s.bias is not None: + bias = s.bias.to(dtype=bias_dtype, copy=True) + return weight, bias, (None, None, None) + + offload_stream = None + xfer_dest = None + + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + if signature is not None: + if resident: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + + if not resident: + cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) + cast_dest = None + + xfer_source = [ s.weight, s.bias ] + + pin = comfy.pinned_memory.get_pin(s) + if pin is not None: + xfer_source = [ pin ] + + for data, geometry in zip([ s.weight, s.bias ], cast_geometry): + if data is None: + continue + if data.dtype != geometry.dtype: + cast_dest = xfer_dest + if cast_dest is None: + cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) + xfer_dest = None + break + + dest_size = comfy.memory_management.vram_aligned_size(xfer_source) + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None: + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + offload_stream = comfy.model_management.get_offload_stream(device) + xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s) + if xfer_dest is None: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None + + if signature is None and pin is None: + comfy.pinned_memory.pin_memory(s) + pin = comfy.pinned_memory.get_pin(s) + else: + pin = None + + if pin is not None: + comfy.model_management.cast_to_gathered(xfer_source, pin) + xfer_source = [ pin ] + #send it over + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + comfy.model_management.sync_stream(device, offload_stream) + + if cast_dest is not None: + for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), + comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + if post_cast is not None: + post_cast.copy_(pre_cast) + xfer_dest = cast_dest + + params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + weight = 
params[0]
+    bias = params[1]
+    if signature is not None:
+        s._v_weight = weight
+        s._v_bias = bias
+        s._v_signature = signature
+
+    def post_cast(s, param_key, x, dtype, resident, update_weight):
+        lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+        fns = getattr(s, param_key + "_function", [])
+
+        orig = x
+
+        def to_dequant(tensor, dtype):
+            tensor = tensor.to(dtype=dtype)
+            if isinstance(tensor, QuantizedTensor):
+                tensor = tensor.dequantize()
+            return tensor
+
+        if orig.dtype != dtype or len(fns) > 0:
+            x = to_dequant(x, dtype)
+        if not resident and lowvram_fn is not None:
+            x = to_dequant(x, dtype if compute_dtype is None else compute_dtype)
+            x = lowvram_fn(x)
+            if (want_requant and len(fns) == 0 or update_weight):
+                seed = comfy.utils.string_to_seed(s.seed_key)
+                if isinstance(orig, QuantizedTensor):
+                    y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
+                else:
+                    y = comfy.float.stochastic_rounding(x, orig.dtype, seed=seed)
+                if want_requant and len(fns) == 0:
+                    x = y
+                if update_weight:
+                    orig.copy_(y)
+        for f in fns:
+            x = f(x)
+        return x
+
+    update_weight = signature is not None
+
+    weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
+    if s.bias is not None:
+        bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
+
+    #FIXME: weird offload return protocol
+    return weight, bias, (offload_stream, device if signature is not None else None, None)
+
+
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
     # NOTE: offloadable=False is a legacy default and if you are a custom node author reading this please pass
     # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
     # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
             if isinstance(input, QuantizedTensor):
-                dtype = input._layout_params["orig_dtype"]
+                dtype = input.params.orig_dtype
             else:
                 dtype = input.dtype
         if bias_dtype is None:
@@ -87,22 +222,38 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
         if device is None:
             device = input.device
 
+    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+
+    if hasattr(s, "_v"):
+        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
+
     if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)):
         offload_stream = comfy.model_management.get_offload_stream(device)
     else:
         offload_stream = None
 
-    non_blocking = comfy.model_management.device_supports_non_blocking(device)
+    bias = None
+    weight = None
+
+    if offload_stream is not None and not args.cuda_malloc:
+        cast_buffer_size = comfy.memory_management.vram_aligned_size([ s.weight, s.bias ])
+        cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
+        #The streams can be uneven in buffer capability and reject us.
Retry to get the other stream + if cast_buffer is None: + offload_stream = comfy.model_management.get_offload_stream(device) + cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s) + params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer) + weight = params[0] + bias = params[1] weight_has_function = len(s.weight_function) > 0 bias_has_function = len(s.bias_function) > 0 - weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream, r=weight) - bias = None if s.bias is not None: - bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) + bias = comfy.model_management.cast_to(s.bias, None, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream, r=bias) comfy.model_management.sync_stream(device, offload_stream) @@ -110,6 +261,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of weight_a = weight if s.bias is not None: + bias = bias.to(dtype=bias_dtype) for f in s.bias_function: bias = f(bias) @@ -131,14 +283,20 @@ def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return os, weight_a, bias_a = offload_stream + device=None + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): + comfy_aimdo.model_vbar.vbar_unpin(s._v) + device = weight_a if os is None: return - if weight_a is not None: - device = weight_a.device - else: - if bias_a is None: - return - device = bias_a.device + if device is None: + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device os.wait_stream(comfy.model_management.current_stream(device)) @@ -148,7 +306,77 @@ class CastWeightBiasOp: bias_function = [] class disable_weight_init: + @staticmethod + def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, + missing_keys, unexpected_keys, weight_shape, + bias_shape=None): + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k, v in state_dict.items(): + key = k[prefix_len:] + if key == "weight": + if not assign_to_params_buffers: + v = v.clone() + module.weight = torch.nn.Parameter(v, requires_grad=False) + elif bias_shape is not None and key == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + module.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + if module.weight is None: + module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False) + missing_keys.append(prefix + "weight") + + if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False): + module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False) + missing_keys.append(prefix + "bias") + class Linear(torch.nn.Linear, CastWeightBiasOp): + + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + # don't trust subclasses that BYO state dict loader to call us. 
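+            # If a subclass overrides _load_from_state_dict, our lazy loader below may
+            # never run, so fall back to eager torch.nn.Linear construction instead.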
+            if (not comfy.model_management.WINDOWS
+                or not comfy.memory_management.aimdo_enabled
+                or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
+                super().__init__(in_features, out_features, bias, device, dtype)
+                return
+
+            # Issue is with `torch.empty` still reserving the full memory for the layer.
+            # Windows doesn't over-commit memory so without this, we are momentarily commit
+            # charged for the weight even though we might zero-copy it when we load the
+            # state dict. If the commit charge exceeds the ceiling we can destabilize the
+            # system.
+            torch.nn.Module.__init__(self)
+            self.in_features = in_features
+            self.out_features = out_features
+            self.weight = None
+            self.bias = None
+            self.comfy_need_lazy_init_bias = bias
+            self.weight_comfy_model_dtype = dtype
+            self.bias_comfy_model_dtype = dtype
+
+        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                  strict, missing_keys, unexpected_keys, error_msgs):
+
+            if (not comfy.model_management.WINDOWS
+                or not comfy.memory_management.aimdo_enabled
+                or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
+                return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
+                                                     missing_keys, unexpected_keys, error_msgs)
+            disable_weight_init._lazy_load_from_state_dict(
+                self,
+                state_dict,
+                prefix,
+                local_metadata,
+                missing_keys,
+                unexpected_keys,
+                # nn.Linear stores weight as (out_features, in_features)
+                weight_shape=(self.out_features, self.in_features),
+                bias_shape=(self.out_features,),
+            )
+
         def reset_parameters(self):
             return None
@@ -203,7 +431,9 @@ class disable_weight_init:
         def reset_parameters(self):
             return None
 
-        def _conv_forward(self, input, weight, bias, *args, **kwargs):
+        def _conv_forward(self, input, weight, bias, autopad=None, *args, **kwargs):
+            if autopad == "causal_zero":
+                weight = weight[:, :, -input.shape[2]:, :, :]
             if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
                 out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
                 if bias is not None:
@@ -212,15 +442,15 @@ class disable_weight_init:
             else:
                 return super()._conv_forward(input, weight, bias, *args, **kwargs)
 
-        def forward_comfy_cast_weights(self, input):
+        def forward_comfy_cast_weights(self, input, autopad=None):
             weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
-            x = self._conv_forward(input, weight, bias)
+            x = self._conv_forward(input, weight, bias, autopad=autopad)
             uncast_bias_weight(self, weight, bias, offload_stream)
             return x
 
         def forward(self, *args, **kwargs):
             run_every_op()
-            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0 or "autopad" in kwargs:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -264,7 +494,7 @@ class disable_weight_init:
             else:
                 return super().forward(*args, **kwargs)
 
-    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+    class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
         def reset_parameters(self):
             self.bias = None
             return None
@@ -276,8 +506,7 @@ class disable_weight_init:
                 weight = None
                 bias = None
                 offload_stream = None
-            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            x =
torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -337,6 +566,53 @@ class disable_weight_init: return super().forward(*args, **kwargs) class Embedding(torch.nn.Embedding, CastWeightBiasOp): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, + norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, + _freeze=False, device=None, dtype=None): + # don't trust subclasses that BYO state dict loader to call us. + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): + super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, + norm_type, scale_grad_by_freq, sparse, _weight, + _freeze, device, dtype) + return + + torch.nn.Module.__init__(self) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + # Keep shape/dtype visible for module introspection without reserving storage. + embedding_dtype = dtype if dtype is not None else torch.get_default_dtype() + self.weight = torch.nn.Parameter( + torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype), + requires_grad=False, + ) + self.bias = None + self.weight_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + disable_weight_init._lazy_load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + missing_keys, + unexpected_keys, + weight_shape=(self.num_embeddings, self.embedding_dim), + ) + def reset_parameters(self): self.bias = None return None @@ -412,26 +688,35 @@ def fp8_linear(self, input): return None input_dtype = input.dtype + input_shape = input.shape + tensor_3d = input.ndim == 3 - if input.ndim == 3 or input.ndim == 2: - w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) + if tensor_3d: + input = input.reshape(-1, input_shape[2]) - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) + if input.ndim != 2: + return None + lora_compute_dtype=comfy.model_management.lora_compute_dtype(input.device) + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True, compute_dtype=lora_compute_dtype, want_requant=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - # Wrap weight in QuantizedTensor - this enables unified dispatch - # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
- layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} - quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) - o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + input_fp8 = input.to(dtype).contiguous() + layout_params_input = TensorCoreFP8Layout.Params(scale=scale_input, orig_dtype=input_dtype, orig_shape=tuple(input_fp8.shape)) + quantized_input = QuantizedTensor(input_fp8, "TensorCoreFP8Layout", layout_params_input) - uncast_bias_weight(self, w, bias, offload_stream) - return o + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! + layout_params_weight = TensorCoreFP8Layout.Params(scale=scale_weight, orig_dtype=input_dtype, orig_shape=tuple(w.shape)) + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - return None + uncast_bias_weight(self, w, bias, offload_stream) + if tensor_3d: + o = o.reshape((input_shape[0], input_shape[1], w.shape[0])) + + return o class fp8_ops(manual_cast): class Linear(manual_cast.Linear): @@ -456,35 +741,145 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations # ============================================================================== -from .quant_ops import QuantizedTensor, QUANT_ALGOS +from .quant_ops import ( + QuantizedTensor, + QUANT_ALGOS, + TensorCoreFP8Layout, + get_layout_class, +) -def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +class QuantLinearFunc(torch.autograd.Function): + """Custom autograd function for quantized linear: quantized forward, optionally FP8 backward. 
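+
+    Call shape mirrors forward() below; a sketch (names are illustrative):
+        out = QuantLinearFunc.apply(x, weight, bias, layout_type, input_scale, compute_dtype)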
+ + When training_fp8_bwd is enabled: + - Forward: quantize input per layout (FP8/NVFP4), use quantized matmul + - Backward: all matmuls use FP8 tensor cores via torch.mm dispatch + - Cached input is FP8 (half the memory of bf16) + + When training_fp8_bwd is disabled: + - Forward: quantize input per layout, use quantized matmul + - Backward: dequantize weight to compute_dtype, use standard matmul + """ + + @staticmethod + def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): + input_shape = input_float.shape + inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D + + # Quantize input for forward (same layout as weight) + if layout_type is not None: + q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) + else: + q_input = inp + + w = weight.detach() if weight.requires_grad else weight + b = bias.detach() if bias is not None and bias.requires_grad else bias + + output = torch.nn.functional.linear(q_input, w, b) + + # Unflatten output to match original input shape + if len(input_shape) > 2: + output = output.unflatten(0, input_shape[:-1]) + + # Save for backward + ctx.input_shape = input_shape + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + ctx.fp8_bwd = comfy.model_management.training_fp8_bwd + + if ctx.fp8_bwd: + # Cache FP8 quantized input — half the memory of bf16 + if isinstance(q_input, QuantizedTensor) and layout_type.startswith('TensorCoreFP8'): + ctx.q_input = q_input # already FP8, reuse + else: + # NVFP4 or other layout — quantize input to FP8 for backward + ctx.q_input = QuantizedTensor.from_float(inp, "TensorCoreFP8E4M3Layout") + ctx.save_for_backward(weight) + else: + ctx.q_input = None + ctx.save_for_backward(input_float, weight) + + return output + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + compute_dtype = ctx.compute_dtype + grad_2d = grad_output.flatten(0, -2).to(compute_dtype) + + # Value casting — only difference between fp8 and non-fp8 paths + if ctx.fp8_bwd: + weight, = ctx.saved_tensors + # Wrap as FP8 QuantizedTensors → torch.mm dispatches to _scaled_mm + grad_mm = QuantizedTensor.from_float(grad_2d, "TensorCoreFP8E5M2Layout") + if isinstance(weight, QuantizedTensor) and weight._layout_cls.startswith("TensorCoreFP8"): + weight_mm = weight + elif isinstance(weight, QuantizedTensor): + weight_mm = QuantizedTensor.from_float(weight.dequantize().to(compute_dtype), "TensorCoreFP8E4M3Layout") + else: + weight_mm = QuantizedTensor.from_float(weight.to(compute_dtype), "TensorCoreFP8E4M3Layout") + input_mm = ctx.q_input + else: + input_float, weight = ctx.saved_tensors + # Standard tensors → torch.mm does regular matmul + grad_mm = grad_2d + if isinstance(weight, QuantizedTensor): + weight_mm = weight.dequantize().to(compute_dtype) + else: + weight_mm = weight.to(compute_dtype) + input_mm = input_float.flatten(0, -2).to(compute_dtype) if ctx.weight_requires_grad else None + + # Computation — same for both paths, dispatch handles the rest + grad_input = torch.mm(grad_mm, weight_mm) + if len(ctx.input_shape) > 2: + grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) + + grad_weight = None + if ctx.weight_requires_grad: + grad_weight = torch.mm(grad_mm.t(), input_mm) + + grad_bias = None + if ctx.has_bias: + grad_bias = grad_2d.sum(dim=0) + + return grad_input, grad_weight, grad_bias, None, None, None + + +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, 
full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm + _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): def __init__( @@ -497,21 +892,33 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec ) -> None: super().__init__() - if dtype is None: - dtype = MixedPrecisionOps._compute_dtype - - self.factory_kwargs = {"device": device, "dtype": dtype} + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features - self._has_bias = bias + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) self.tensor_class = None self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False def reset_parameters(self): return None + def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): + key = f"{prefix}{param_name}" + value = state_dict.pop(key, None) + if value is not None: + value = value.to(device=device) + if dtype is not None: + value = value.view(dtype=dtype) + manually_loaded_keys.append(key) + return value + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -520,7 +927,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec weight_key = f"{prefix}weight" weight = state_dict.pop(weight_key, None) if weight is None: - raise ValueError(f"Missing weight for layer {layer_name}") + logging.warning(f"Missing weight for layer {layer_name}") + return manually_loaded_keys = [weight_key] @@ -529,49 +937,77 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec layer_conf = json.loads(layer_conf.numpy().tobytes()) if layer_conf is None: - dtype = self.factory_kwargs["dtype"] - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False) - if dtype != MixedPrecisionOps._compute_dtype: - self.comfy_cast_weights = True - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype)) - else: - self.register_parameter("bias", None) + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: self.quant_format = layer_conf.get("format", None) + self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) if not self._full_precision_mm: - self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + self._full_precision_mm = self._full_precision_mm_config + + if self.quant_format in MixedPrecisionOps._disabled: + self._full_precision_mm = True if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(self.layout_type) - weight_scale_key = f"{prefix}weight_scale" - scale = state_dict.pop(weight_scale_key, None) - if scale is not None: - scale = scale.to(device) - layout_params = { - 'scale': scale, - 'orig_dtype': MixedPrecisionOps._compute_dtype, - 'block_size': qconfig.get("group_size", None), - } + # Load format-specific parameters + if 
self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: + # FP8: single tensor scale + scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - if scale is not None: - manually_loaded_keys.append(weight_scale_key) + params = layout_cls.Params( + scale=scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "mxfp8": + # MXFP8: E8M0 block scales stored as uint8 in safetensors + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.uint8) + + if block_scale is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + + block_scale = block_scale.view(torch.float8_e8m0fnu) + + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + + elif self.quant_format == "nvfp4": + # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) + tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.float8_e4m3fn) + + if tensor_scale is None or block_scale is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + + params = layout_cls.Params( + scale=tensor_scale, + block_scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + else: + raise ValueError(f"Unsupported quantization format: {self.quant_format}") self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), requires_grad=False ) - if self._has_bias: - self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype)) - else: - self.register_parameter("bias", None) - for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue # Already handled above + param_key = f"{prefix}{param_name}" _v = state_dict.pop(param_key, None) if _v is None: @@ -586,20 +1022,40 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec missing_keys.remove(key) def state_dict(self, *args, destination=None, prefix="", **kwargs): - sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight'): + logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) + return sd + + if self.bias is not None: + sd["{}bias".format(prefix)] = self.bias + if isinstance(self.weight, QuantizedTensor): - sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + quant_conf = {"format": self.quant_format} - if self._full_precision_mm: + if self._full_precision_mm_config: quant_conf["full_precision_matrix_mult"] = True sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + + input_scale = getattr(self, 'input_scale', None) + if input_scale is not None: + sd["{}input_scale".format(prefix)] = input_scale + else: + 
sd["{}weight".format(prefix)] = self.weight return sd def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) - def forward_comfy_cast_weights(self, input): - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant) x = self._forward(input, weight, bias) uncast_bias_weight(self, weight, bias, offload_stream) return x @@ -607,12 +1063,62 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def forward(self, input, *args, **kwargs): run_every_op() - if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: - return self.forward_comfy_cast_weights(input, *args, **kwargs) - if (getattr(self, 'layout_type', None) is not None and - not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) - return self._forward(input, self.weight, self.bias) + input_shape = input.shape + reshaped_3d = False + #If cast needs to apply lora, it should be done in the compute dtype + compute_dtype = input.dtype + + _use_quantized = ( + getattr(self, 'layout_type', None) is not None and + not isinstance(input, QuantizedTensor) and not self._full_precision_mm and + not getattr(self, 'comfy_force_cast_weights', False) and + len(self.weight_function) == 0 and len(self.bias_function) == 0 + ) + + # Training path: quantized forward with compute_dtype backward via autograd function + if (input.requires_grad and _use_quantized): + + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=True + ) + + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + + output = QuantLinearFunc.apply( + input, weight, bias, self.layout_type, scale, compute_dtype + ) + + uncast_bias_weight(self, weight, bias, offload_stream) + return output + + # Inference path (unchanged) + if _use_quantized: + + # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) + input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input + + # Fall back to non-quantized for non-2D tensors + if input_reshaped.ndim == 2: + reshaped_3d = input.ndim == 3 + # dtype is now implicit in the layout class + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) + + output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor)) + + # Reshape output back to 3D if input was 3D + if reshaped_3d: + output = output.reshape((input_shape[0], input_shape[1], self.weight.shape[0])) + + return output def convert_weight(self, weight, inplace=False, **kwargs): if isinstance(weight, QuantizedTensor): @@ -622,7 +1128,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, 
self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + # dtype is now implicit in the layout class + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", stochastic_rounding=seed, inplace_ops=True).to(self.weight.dtype) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -639,7 +1146,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec for key, param in self._parameters.items(): if param is None: continue - self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + p = fn(param) + if p.is_inference(): + p = p.clone() + self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items(): if buf is not None: self._buffers[key] = fn(buf) @@ -649,10 +1159,20 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") - return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) + disabled = set() + if not nvfp4_compute: + disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") + if not fp8_compute: + disabled.add("float8_e4m3fn") + disabled.add("float8_e5m2") + return mixed_precision_ops(model_config.quant_config, compute_dtype, disabled=disabled) if ( fp8_compute and diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py new file mode 100644 index 000000000..f6fb806c4 --- /dev/null +++ b/comfy/pinned_memory.py @@ -0,0 +1,43 @@ +import comfy.model_management +import comfy.memory_management +import comfy_aimdo.host_buffer +import comfy_aimdo.torch + +from comfy.cli_args import args + +def get_pin(module): + return getattr(module, "_pin", None) + +def pin_memory(module): + if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + return + #FIXME: This is a RAM cache trigger event + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + + if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: + module.pin_failed = True + return False + + try: + hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + except RuntimeError: + module.pin_failed = True + return False + + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) + module._pin_hostbuf = hostbuf + comfy.model_management.TOTAL_PINNED_MEMORY += size + return True + +def unpin_memory(module): + if get_pin(module) is None: + return 0 + size = module._pin.numel() * module._pin.element_size() + + comfy.model_management.TOTAL_PINNED_MEMORY -= size + if comfy.model_management.TOTAL_PINNED_MEMORY < 0: + comfy.model_management.TOTAL_PINNED_MEMORY = 0 + + del module._pin + del module._pin_hostbuf + return size diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index cd96541d7..42ee08fb2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -1,580 +1,221 @@ import 
torch import logging -from typing import Tuple, Dict + +try: + import comfy_kitchen as ck + from comfy_kitchen.tensor import ( + QuantizedTensor, + QuantizedLayout, + TensorCoreFP8Layout as _CKFp8Layout, + TensorCoreNVFP4Layout as _CKNvfp4Layout, + register_layout_op, + register_layout_class, + get_layout_class, + ) + _CK_AVAILABLE = True + if torch.version.cuda is None: + ck.registry.disable("cuda") + else: + cuda_version = tuple(map(int, str(torch.version.cuda).split('.'))) + if cuda_version < (13,): + ck.registry.disable("cuda") + logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") + + ck.registry.disable("triton") + for k, v in ck.list_backends().items(): + logging.info(f"Found comfy_kitchen backend {k}: {v}") +except ImportError as e: + logging.error(f"Failed to import comfy_kitchen, Error: {e}, fp8 and fp4 support will not be available.") + _CK_AVAILABLE = False + + class QuantizedTensor: + pass + + class _CKFp8Layout: + pass + + class _CKNvfp4Layout: + pass + + def register_layout_class(name, cls): + pass + + def get_layout_class(name): + return None + +_CK_MXFP8_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout + _CK_MXFP8_AVAILABLE = True + except ImportError: + logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.") + +if not _CK_MXFP8_AVAILABLE: + class _CKMxfp8Layout: + pass + import comfy.float -_LAYOUT_REGISTRY = {} -_GENERIC_UTILS = {} - - -def register_layout_op(torch_op, layout_type): - """ - Decorator to register a layout-specific operation handler. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) - layout_type: Layout class (e.g., TensorCoreFP8Layout) - Example: - @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) - def fp8_linear(func, args, kwargs): - # FP8-specific linear implementation - ... - """ - def decorator(handler_func): - if torch_op not in _LAYOUT_REGISTRY: - _LAYOUT_REGISTRY[torch_op] = {} - _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func - return handler_func - return decorator - - -def register_generic_util(torch_op): - """ - Decorator to register a generic utility that works for all layouts. - Args: - torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) - - Example: - @register_generic_util(torch.ops.aten.detach.default) - def generic_detach(func, args, kwargs): - # Works for any layout - ... - """ - def decorator(handler_func): - _GENERIC_UTILS[torch_op] = handler_func - return handler_func - return decorator - - -def _get_layout_from_args(args): - for arg in args: - if isinstance(arg, QuantizedTensor): - return arg._layout_type - elif isinstance(arg, (list, tuple)): - for item in arg: - if isinstance(item, QuantizedTensor): - return item._layout_type - return None - - -def _move_layout_params_to_device(params, device): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.to(device=device) - else: - new_params[k] = v - return new_params - - -def _copy_layout_params(params): - new_params = {} - for k, v in params.items(): - if isinstance(v, torch.Tensor): - new_params[k] = v.clone() - else: - new_params[k] = v - return new_params - -def _copy_layout_params_inplace(src, dst, non_blocking=False): - for k, v in src.items(): - if isinstance(v, torch.Tensor): - dst[k].copy_(v, non_blocking=non_blocking) - else: - dst[k] = v - -class QuantizedLayout: - """ - Base class for quantization layouts. 
- - A layout encapsulates the format-specific logic for quantization/dequantization - and provides a uniform interface for extracting raw tensors needed for computation. - - New quantization formats should subclass this and implement the required methods. - """ - @classmethod - def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: - raise NotImplementedError(f"{cls.__name__} must implement quantize()") - - @staticmethod - def dequantize(qdata, **layout_params) -> torch.Tensor: - raise NotImplementedError("TensorLayout must implement dequantize()") - - @classmethod - def get_plain_tensors(cls, qtensor) -> torch.Tensor: - raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") - - -class QuantizedTensor(torch.Tensor): - """ - Universal quantized tensor that works with any layout. - - This tensor subclass uses a pluggable layout system to support multiple - quantization formats (FP8, INT4, INT8, etc.) without code duplication. - - The layout_type determines format-specific behavior, while common operations - (detach, clone, to) are handled generically. - - Attributes: - _qdata: The quantized tensor data - _layout_type: Layout class (e.g., TensorCoreFP8Layout) - _layout_params: Dict with layout-specific params (scale, zero_point, etc.) - """ - - @staticmethod - def __new__(cls, qdata, layout_type, layout_params): - """ - Create a quantized tensor. - - Args: - qdata: The quantized data tensor - layout_type: Layout class (subclass of QuantizedLayout) - layout_params: Dict with layout-specific parameters - """ - return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) - - def __init__(self, qdata, layout_type, layout_params): - self._qdata = qdata - self._layout_type = layout_type - self._layout_params = layout_params - - def __repr__(self): - layout_name = self._layout_type - param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) - return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" - - @property - def layout_type(self): - return self._layout_type - - def __tensor_flatten__(self): - """ - Tensor flattening protocol for proper device movement. - """ - inner_tensors = ["_qdata"] - ctx = { - "layout_type": self._layout_type, - } - - tensor_params = {} - non_tensor_params = {} - for k, v in self._layout_params.items(): - if isinstance(v, torch.Tensor): - tensor_params[k] = v - else: - non_tensor_params[k] = v - - ctx["tensor_param_keys"] = list(tensor_params.keys()) - ctx["non_tensor_params"] = non_tensor_params - - for k, v in tensor_params.items(): - attr_name = f"_layout_param_{k}" - object.__setattr__(self, attr_name, v) - inner_tensors.append(attr_name) - - return inner_tensors, ctx - - @staticmethod - def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): - """ - Tensor unflattening protocol for proper device movement. - Reconstructs the QuantizedTensor after device movement. 
- """ - layout_type = ctx["layout_type"] - layout_params = dict(ctx["non_tensor_params"]) - - for key in ctx["tensor_param_keys"]: - attr_name = f"_layout_param_{key}" - layout_params[key] = inner_tensors[attr_name] - - return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) - - @classmethod - def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': - qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) - return cls(qdata, layout_type, layout_params) - - def dequantize(self) -> torch.Tensor: - return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - kwargs = kwargs or {} - - # Step 1: Check generic utilities first (detach, clone, to, etc.) - if func in _GENERIC_UTILS: - return _GENERIC_UTILS[func](func, args, kwargs) - - # Step 2: Check layout-specific handlers (linear, matmul, etc.) - layout_type = _get_layout_from_args(args) - if layout_type and func in _LAYOUT_REGISTRY: - handler = _LAYOUT_REGISTRY[func].get(layout_type) - if handler: - return handler(func, args, kwargs) - - # Step 3: Fallback to dequantization - if isinstance(args[0] if args else None, QuantizedTensor): - logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}") - return cls._dequant_and_fallback(func, args, kwargs) - - @classmethod - def _dequant_and_fallback(cls, func, args, kwargs): - def dequant_arg(arg): - if isinstance(arg, QuantizedTensor): - return arg.dequantize() - elif isinstance(arg, (list, tuple)): - return type(arg)(dequant_arg(a) for a in arg) - return arg - - new_args = dequant_arg(args) - new_kwargs = dequant_arg(kwargs) - return func(*new_args, **new_kwargs) - - def data_ptr(self): - return self._qdata.data_ptr() - - def is_pinned(self): - return self._qdata.is_pinned() - - def is_contiguous(self, *arg, **kwargs): - return self._qdata.is_contiguous(*arg, **kwargs) - - def storage(self): - return self._qdata.storage() - # ============================================================================== -# Generic Utilities (Layout-Agnostic Operations) +# FP8 Layouts with Comfy-Specific Extensions # ============================================================================== -def _create_transformed_qtensor(qt, transform_fn): - new_data = transform_fn(qt._qdata) - new_params = _copy_layout_params(qt._layout_params) - return QuantizedTensor(new_data, qt._layout_type, new_params) +class _TensorCoreFP8LayoutBase(_CKFp8Layout): + FP8_DTYPE = None # Must be overridden in subclass - -def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_layout is not None and target_layout != torch.strided: - logging.warning( - f"QuantizedTensor: layout change requested to {target_layout}, " - f"but not supported. Ignoring layout." 
- ) - - # Handle device transfer - current_device = qt._qdata.device - if target_device is not None: - # Normalize device for comparison - if isinstance(target_device, str): - target_device = torch.device(target_device) - if isinstance(current_device, str): - current_device = torch.device(current_device) - - if target_device != current_device: - logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") - new_q_data = qt._qdata.to(device=target_device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - if target_dtype is not None: - new_params["orig_dtype"] = target_dtype - new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) - logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") - return new_qt - - logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") - return qt - - -@register_generic_util(torch.ops.aten.detach.default) -def generic_detach(func, args, kwargs): - """Detach operation - creates a detached copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.detach()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.clone.default) -def generic_clone(func, args, kwargs): - """Clone operation - creates a deep copy of the quantized tensor.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _create_transformed_qtensor(qt, lambda x: x.clone()) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._to_copy.default) -def generic_to_copy(func, args, kwargs): - """Device/dtype transfer operation - handles .to(device) calls.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - op_name="_to_copy" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype_layout) -def generic_to_dtype_layout(func, args, kwargs): - """Handle .to(device) calls using the dtype_layout variant.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - return _handle_device_transfer( - qt, - target_device=kwargs.get('device', None), - target_dtype=kwargs.get('dtype', None), - target_layout=kwargs.get('layout', None), - op_name="to" - ) - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.copy_.default) -def generic_copy_(func, args, kwargs): - qt_dest = args[0] - src = args[1] - non_blocking = args[2] if len(args) > 2 else False - if isinstance(qt_dest, QuantizedTensor): - if isinstance(src, QuantizedTensor): - # Copy from another quantized tensor - qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) - qt_dest._layout_type = src._layout_type - orig_dtype = qt_dest._layout_params["orig_dtype"] - _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) - qt_dest._layout_params["orig_dtype"] = orig_dtype - else: - # Copy from regular tensor - just copy raw data - qt_dest._qdata.copy_(src) - return qt_dest - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten.to.dtype) -def generic_to_dtype(func, args, kwargs): - """Handle .to(dtype) calls - dtype conversion only.""" - src = args[0] - if isinstance(src, QuantizedTensor): - # For dtype-only conversion, just change the orig_dtype, no real cast is needed - target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') - src._layout_params["orig_dtype"] = 
target_dtype - return src - return func(*args, **kwargs) - - -@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) -def generic_has_compatible_shallow_copy_type(func, args, kwargs): - return True - - -@register_generic_util(torch.ops.aten.empty_like.default) -def generic_empty_like(func, args, kwargs): - """Empty_like operation - creates an empty tensor with the same quantized structure.""" - qt = args[0] - if isinstance(qt, QuantizedTensor): - # Create empty tensor with same shape and dtype as the quantized data - hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) - new_qdata = torch.empty_like(qt._qdata, **kwargs) - - # Handle device transfer for layout params - target_device = kwargs.get('device', new_qdata.device) - new_params = _move_layout_params_to_device(qt._layout_params, target_device) - - # Update orig_dtype if dtype is specified - new_params['orig_dtype'] = hp_dtype - - return QuantizedTensor(new_qdata, qt._layout_type, new_params) - return func(*args, **kwargs) - -# ============================================================================== -# FP8 Layout + Operation Handlers -# ============================================================================== -class TensorCoreFP8Layout(QuantizedLayout): - """ - Storage format: - - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) - - scale: Scalar tensor (float32) for dequantization - - orig_dtype: Original dtype before quantization (for casting back) - """ @classmethod - def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if cls.FP8_DTYPE is None: + raise NotImplementedError(f"{cls.__name__} must define FP8_DTYPE") + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) if isinstance(scale, str) and scale == "recalculate": - scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max + scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(cls.FP8_DTYPE).max if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small tensor_info = torch.finfo(tensor.dtype) scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max)) - if scale is not None: - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is None: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale, device=tensor.device, dtype=torch.float32) + if stochastic_rounding > 0: if inplace_ops: tensor *= (1.0 / scale).to(tensor.dtype) else: tensor = tensor * (1.0 / scale).to(tensor.dtype) + qdata = comfy.float.stochastic_rounding(tensor, dtype=cls.FP8_DTYPE, seed=stochastic_rounding) else: - scale = torch.ones((), device=tensor.device, dtype=torch.float32) + qdata = ck.quantize_per_tensor_fp8(tensor, scale, cls.FP8_DTYPE) + + params = cls.Params(scale=scale.float(), orig_dtype=orig_dtype, orig_shape=orig_shape) + return qdata, params + + +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape 
!= orig_shape if stochastic_rounding > 0: - tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) else: - lp_amax = torch.finfo(dtype).max - torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) - tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) - layout_params = { - 'scale': scale, - 'orig_dtype': orig_dtype - } - return tensor, layout_params + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params - @staticmethod - def dequantize(qdata, scale, orig_dtype, **kwargs): - plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) - plain_tensor.mul_(scale) - return plain_tensor +class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod - def get_plain_tensors(cls, qtensor): - return qtensor._qdata, qtensor._layout_params['scale'] + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + if scale is None or (isinstance(scale, str) and scale == "recalculate"): + scale = torch.amax(tensor.abs()) / (ck.float_utils.F8_E4M3_MAX * ck.float_utils.F4_E2M1_MAX) + + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_nvfp4_by_block(tensor, scale, pad_16x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_nvfp4(tensor, scale, pad_16x=needs_padding) + + params = cls.Params( + scale=scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + block_scale=block_scale, + ) + return qdata, params + + +class TensorCoreFP8E4M3Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e4m3fn + + +class TensorCoreFP8E5M2Layout(_TensorCoreFP8LayoutBase): + FP8_DTYPE = torch.float8_e5m2 + + +# Backward compatibility alias - default to E4M3 +TensorCoreFP8Layout = TensorCoreFP8E4M3Layout + + +# ============================================================================== +# Registry +# ============================================================================== + +register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) +register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) +register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) +register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +if _CK_MXFP8_AVAILABLE: + register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { "storage_t": torch.float8_e4m3fn, "parameters": {"weight_scale", "input_scale"}, - "comfy_tensor_layout": "TensorCoreFP8Layout", + "comfy_tensor_layout": "TensorCoreFP8E4M3Layout", + }, + "float8_e5m2": { + "storage_t": torch.float8_e5m2, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8E5M2Layout", + }, + "nvfp4": { + "storage_t": torch.uint8, + "parameters": {"weight_scale", "weight_scale_2", "input_scale"}, + "comfy_tensor_layout": "TensorCoreNVFP4Layout", + "group_size": 16, }, } -LAYOUTS 
= { - "TensorCoreFP8Layout": TensorCoreFP8Layout, -} +if _CK_MXFP8_AVAILABLE: + QUANT_ALGOS["mxfp8"] = { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreMXFP8Layout", + "group_size": 32, + } -@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") -def fp8_linear(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None +# ============================================================================== +# Re-exports for backward compatibility +# ============================================================================== - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - out_dtype = kwargs.get("out_dtype") - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - weight_t = plain_weight.t() - - tensor_2d = False - if len(plain_input.shape) == 2: - tensor_2d = True - plain_input = plain_input.unsqueeze(1) - - input_shape = plain_input.shape - if len(input_shape) != 3: - return None - - try: - output = torch._scaled_mm( - plain_input.reshape(-1, input_shape[2]).contiguous(), - weight_t, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - - if not tensor_2d: - output = output.reshape((-1, input_shape[1], weight.shape[0])) - - if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: - output_scale = scale_a * scale_b - output_params = { - 'scale': output_scale, - 'orig_dtype': input_tensor._layout_params['orig_dtype'] - } - return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) - else: - return output - - except Exception as e: - raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") - - # Case 2: DQ Fallback - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - if isinstance(input_tensor, QuantizedTensor): - input_tensor = input_tensor.dequantize() - - return torch.nn.functional.linear(input_tensor, weight, bias) - -def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): - if out_dtype is None: - out_dtype = input_tensor._layout_params['orig_dtype'] - - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) - - output = torch._scaled_mm( - plain_input.contiguous(), - plain_weight, - bias=bias, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - ) - - if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 - output = output[0] - return output - -@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") -def fp8_addmm(func, args, kwargs): - input_tensor = args[1] - weight = args[2] - bias = args[0] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - if isinstance(args[2], QuantizedTensor): - a[2] = args[2].dequantize() - - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.mm.default, 
"TensorCoreFP8Layout") -def fp8_mm(func, args, kwargs): - input_tensor = args[0] - weight = args[1] - - if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): - return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) - - a = list(args) - if isinstance(args[0], QuantizedTensor): - a[0] = args[0].dequantize() - if isinstance(args[1], QuantizedTensor): - a[1] = args[1].dequantize() - return func(*a, **kwargs) - -@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") -@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") -def fp8_func(func, args, kwargs): - input_tensor = args[0] - if isinstance(input_tensor, QuantizedTensor): - plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) - ar = list(args) - ar[0] = plain_input - return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) - return func(*args, **kwargs) +__all__ = [ + "QuantizedTensor", + "QuantizedLayout", + "TensorCoreFP8Layout", + "TensorCoreFP8E4M3Layout", + "TensorCoreFP8E5M2Layout", + "TensorCoreNVFP4Layout", + "QUANT_ALGOS", + "register_layout_op", +] diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py index 555542a46..ab7cf14fa 100644 --- a/comfy/rmsnorm.py +++ b/comfy/rmsnorm.py @@ -1,57 +1,10 @@ import torch import comfy.model_management -import numbers -import logging - -RMSNorm = None - -try: - rms_norm_torch = torch.nn.functional.rms_norm - RMSNorm = torch.nn.RMSNorm -except: - rms_norm_torch = None - logging.warning("Please update pytorch to use native RMSNorm") +RMSNorm = torch.nn.RMSNorm def rms_norm(x, weight=None, eps=1e-6): - if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()): - if weight is None: - return rms_norm_torch(x, (x.shape[-1],), eps=eps) - else: - return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) + if weight is None: + return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) else: - r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps) - if weight is None: - return r - else: - return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device) - - -if RMSNorm is None: - class RMSNorm(torch.nn.Module): - def __init__( - self, - normalized_shape, - eps=1e-6, - elementwise_affine=True, - device=None, - dtype=None, - ): - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - if isinstance(normalized_shape, numbers.Integral): - # mypy error: incompatible types in assignment - normalized_shape = (normalized_shape,) # type: ignore[assignment] - self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] - self.eps = eps - self.elementwise_affine = elementwise_affine - if self.elementwise_affine: - self.weight = torch.nn.Parameter( - torch.empty(self.normalized_shape, **factory_kwargs) - ) - else: - self.register_parameter("weight", None) - self.bias = None - - def forward(self, x): - return rms_norm(x, self.weight, self.eps) + return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) diff --git a/comfy/sample.py b/comfy/sample.py index 2f8f3a51c..653829582 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -8,12 +8,12 @@ import comfy.nested_tensor def prepare_noise_inner(latent_image, generator, noise_inds=None): if noise_inds is None: - return torch.randn(latent_image.size(), 
dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) unique_inds, inverse = np.unique(noise_inds, return_inverse=True) noises = [] for i in range(unique_inds[-1]+1): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype) if i in unique_inds: noises.append(noise) noises = [noises[i] for i in inverse] @@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None): return noises -def fix_empty_latent_channels(model, latent_image): +def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None): if latent_image.is_nested: return latent_image latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels - if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if torch.count_nonzero(latent_image) == 0: + if latent_format.latent_channels != latent_image.shape[1]: + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + if downscale_ratio_spacial is not None: + if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: + ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio + latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled") + if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: latent_image = latent_image.unsqueeze(2) return latent_image @@ -58,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py index 9134e6d71..bbba09e26 100644 --- 
a/comfy/sampler_helpers.py +++ b/comfy/sampler_helpers.py @@ -66,6 +66,18 @@ def convert_cond(cond): out.append(temp) return out +def cond_has_hooks(cond): + for c in cond: + temp = c[1] + if "hooks" in temp: + return True + if "control" in temp: + control = temp["control"] + extra_hooks = control.get_extra_hooks() + if len(extra_hooks) > 0: + return True + return False + def get_additional_models(conds, dtype): """loads additional models in conditioning""" cnets: list[ControlBase] = [] @@ -122,20 +134,26 @@ def estimate_memory(model, noise_shape, conds): minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min) return memory_required, minimum_memory_required -def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): executor = comfy.patcher_extension.WrapperExecutor.new_executor( _prepare_sampling, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True) ) - return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load) + return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload) -def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False): +def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False): real_model: BaseModel = None models, inference_memory = get_additional_models(conds, model.model_dtype()) models += get_additional_models_from_model_options(model_options) models += model.get_nested_additional_models() # TODO: does this require inference_memory update? 
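For context on the hook check added above: cond_has_hooks() walks a list of [tensor, extras] pairs and looks for a "hooks" key or a "control" object carrying extra hooks. A minimal sketch of that structure; the tensor shape and the bare object() placeholder are hypothetical stand-ins, only the pair layout and the "hooks"/"control" keys come from the code:
import torch
conds = [
    [torch.zeros(1, 77, 768), {}],                   # plain entry: no hooks
    [torch.zeros(1, 77, 768), {"hooks": object()}],  # hook attached directly
]
# cond_has_hooks(conds) would return True because of the second entry; an
# entry carrying {"control": ...} whose get_extra_hooks() is non-empty
# would match as well.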
- memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) - comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load) + if force_offload: # When training with offloading enabled, force prepare_sampling to trigger a partial load + memory_required = 1e20 + minimum_memory_required = None + else: + memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds) + memory_required += inference_memory + minimum_memory_required += inference_memory + comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load) real_model = model.model return real_model, conds, models diff --git a/comfy/samplers.py b/comfy/samplers.py index 8340d376c..0a4d062db 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -9,7 +9,6 @@ if TYPE_CHECKING: import torch from functools import partial import collections -from comfy import model_management import math import logging import comfy.sampler_helpers @@ -260,7 +259,7 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens to_batch_temp.reverse() to_batch = to_batch_temp[:1] - free_memory = model_management.get_free_memory(x_in.device) + free_memory = model.current_patcher.get_free_memory(x_in.device) for i in range(1, len(to_batch_temp) + 1): batch_amount = to_batch_temp[:len(to_batch_temp)//i] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] @@ -947,6 +946,8 @@ class CFGGuider: def inner_set_conds(self, conds): for k in conds: + if self.model_patcher.is_dynamic() and comfy.sampler_helpers.cond_has_hooks(conds[k]): + self.model_patcher = self.model_patcher.get_non_dynamic_delegate() self.original_conds[k] = comfy.sampler_helpers.convert_cond(conds[k]) def __call__(self, *args, **kwargs): @@ -984,11 +985,8 @@ class CFGGuider: self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device - if denoise_mask is not None: - denoise_mask = comfy.sampler_helpers.prepare_mask(denoise_mask, noise.shape, device) - - noise = noise.to(device) - latent_image = latent_image.to(device) + noise = noise.to(device=device, dtype=torch.float32) + latent_image = latent_image.to(device=device, dtype=torch.float32) sigmas = sigmas.to(device) cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) @@ -1013,6 +1011,25 @@ class CFGGuider: else: latent_shapes = [latent_image.shape] + if denoise_mask is not None: + if denoise_mask.is_nested: + denoise_masks = list(denoise_mask.unbind()) # unbind() returns a tuple; make it a list so we can append below + denoise_masks = denoise_masks[:len(latent_shapes)] + else: + denoise_masks = [denoise_mask] + + for i in range(len(denoise_masks), len(latent_shapes)): + denoise_masks.append(torch.ones(latent_shapes[i])) + + for i in range(len(denoise_masks)): + denoise_masks[i] = comfy.sampler_helpers.prepare_mask(denoise_masks[i], latent_shapes[i], self.model_patcher.load_device) + + if len(denoise_masks) > 1: + denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) + else: + denoise_mask = denoise_masks[0] + denoise_mask = denoise_mask.float() + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) diff --git a/comfy/sd.py
b/comfy/sd.py index c2a9728f3..e207bb0fd 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -20,6 +20,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert +import comfy.weight_adapter import yaml import math import os @@ -55,6 +56,11 @@ import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.newbie +import comfy.text_encoders.anima +import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image import comfy.model_patcher import comfy.lora @@ -98,8 +104,107 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): return (new_modelpatcher, new_clip) +def load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip): + """ + Load LoRA in bypass mode without modifying base model weights. + + Instead of patching weights, this injects the LoRA computation into the + forward pass: output = base_forward(x) + lora_path(x) + + Non-adapter patches (bias diff, weight diff, etc.) are applied as regular patches. + + This is useful for training and when model weights are offloaded. + """ + key_map = {} + if model is not None: + key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) + if clip is not None: + key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) + + logging.debug(f"[BypassLoRA] key_map has {len(key_map)} entries") + + lora = comfy.lora_convert.convert_lora(lora) + loaded = comfy.lora.load_lora(lora, key_map) + + logging.debug(f"[BypassLoRA] loaded has {len(loaded)} entries") + + # Separate adapters (for bypass) from other patches (for regular patching) + bypass_patches = {} # WeightAdapterBase instances -> bypass mode + regular_patches = {} # diff, set, bias patches -> regular weight patching + + for key, patch_data in loaded.items(): + if isinstance(patch_data, comfy.weight_adapter.WeightAdapterBase): + bypass_patches[key] = patch_data + else: + regular_patches[key] = patch_data + + logging.debug(f"[BypassLoRA] {len(bypass_patches)} bypass adapters, {len(regular_patches)} regular patches") + + k = set() + k1 = set() + + if model is not None: + new_modelpatcher = model.clone() + + # Apply regular patches (bias diff, weight diff, etc.) 
via normal patching + if regular_patches: + patched_keys = new_modelpatcher.add_patches(regular_patches, strength_model) + k.update(patched_keys) + + # Apply adapter patches via bypass injection + manager = comfy.weight_adapter.BypassInjectionManager() + model_sd_keys = set(new_modelpatcher.model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in model_sd_keys: + manager.add_adapter(key, adapter, strength=strength_model) + k.add(key) + else: + logging.warning(f"[BypassLoRA] Adapter key not in model state_dict: {key}") + + injections = manager.create_injections(new_modelpatcher.model) + + if manager.get_hook_count() > 0: + new_modelpatcher.set_injections("bypass_lora", injections) + else: + new_modelpatcher = None + + if clip is not None: + new_clip = clip.clone() + + # Apply regular patches to clip + if regular_patches: + patched_keys = new_clip.add_patches(regular_patches, strength_clip) + k1.update(patched_keys) + + # Apply adapter patches via bypass injection + clip_manager = comfy.weight_adapter.BypassInjectionManager() + clip_sd_keys = set(new_clip.cond_stage_model.state_dict().keys()) + + for key, adapter in bypass_patches.items(): + if key in clip_sd_keys: + clip_manager.add_adapter(key, adapter, strength=strength_clip) + k1.add(key) + + clip_injections = clip_manager.create_injections(new_clip.cond_stage_model) + if clip_manager.get_hook_count() > 0: + new_clip.patcher.set_injections("bypass_lora", clip_injections) + else: + new_clip = None + + for x in loaded: + if (x not in k) and (x not in k1): + patch_data = loaded[x] + patch_type = type(patch_data).__name__ + if isinstance(patch_data, tuple): + patch_type = f"tuple({patch_data[0]})" + logging.warning(f"NOT LOADED: {x} (type={patch_type})") + + return (new_modelpatcher, new_clip) + + class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}, disable_dynamic=False): if no_init: return params = target.params.copy() @@ -125,8 +230,11 @@ class CLIP: self.cond_stage_model.to(offload_device) logging.warning("Had to shift TE back.") + model_management.archive_model_dtypes(self.cond_stage_model) + self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + self.patcher = ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device) #Match torch.float32 hardcoded upcast in TE implementation self.patcher.set_model_compute_dtype(torch.float32) self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram @@ -160,9 +268,9 @@ class CLIP: logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) self.tokenizer_options = {} - def clone(self): + def clone(self, disable_dynamic=False): n = CLIP(no_init=True) - n.patcher = self.patcher.clone() + n.patcher = self.patcher.clone(disable_dynamic=disable_dynamic) n.cond_stage_model = self.cond_stage_model n.tokenizer = self.tokenizer n.layer_idx = self.layer_idx @@ -216,7 +324,7 @@ class CLIP: if unprojected:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) @@ -264,7 +372,7 @@ class CLIP: if return_pooled == "unprojected": self.cond_stage_model.set_clip_options({"projected_pooled": False}) - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] @@ -286,8 +394,18 @@ class CLIP: def load_sd(self, sd, full_model=False): if full_model: - return self.cond_stage_model.load_state_dict(sd, strict=False) + return self.cond_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) else: + can_assign = self.patcher.is_dynamic() + self.cond_stage_model.can_assign_sd = can_assign + + # The CLIP models are a pretty complex web of wrappers and it's + # a bit of an API change to plumb this all the way through. + # So spray-paint the model with this flag that the loading + # nn.Module can then inspect for itself. + for m in self.cond_stage_model.modules(): + m.can_assign_sd = can_assign + return self.cond_stage_model.load_sd(sd) def get_sd(self): @@ -297,13 +415,27 @@ class CLIP: sd_clip[k] = sd_tokenizer[k] return sd_clip - def load_model(self): - model_management.load_model_gpu(self.patcher) + def load_model(self, tokens={}): + memory_used = 0 + if hasattr(self.cond_stage_model, "memory_estimation_function"): + memory_used = self.cond_stage_model.memory_estimation_function(tokens, device=self.patcher.load_device) + model_management.load_models_gpu([self.patcher], memory_required=memory_used) return self.patcher def get_key_patches(self): return self.patcher.get_key_patches() + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + self.cond_stage_model.reset_clip_options() + + self.load_model(tokens) + self.cond_stage_model.set_clip_options({"layer": None}) + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) + return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) + + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format @@ -323,7 +455,7 @@ class VAE: self.output_channels = 3 self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 - self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False @@ -334,6 +466,8 @@ class VAE: self.extra_1d_channel = None self.crop_input = True + self.audio_sample_rate = 44100 + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -431,14
+565,27 @@ class VAE: encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Decoder", 'params': ddconfig}) elif "decoder.layers.1.layers.0.beta" in sd: - self.first_stage_model = AudioOobleckVAE() + config = {} + param_key = None + self.upscale_ratio = 2048 + self.downscale_ratio = 2048 + if "decoder.layers.2.layers.1.weight_v" in sd: + param_key = "decoder.layers.2.layers.1.weight_v" + if "decoder.layers.2.layers.1.parametrizations.weight.original1" in sd: + param_key = "decoder.layers.2.layers.1.parametrizations.weight.original1" + if param_key is not None: + if sd[param_key].shape[-1] == 12: + config["strides"] = [2, 4, 4, 6, 10] + self.audio_sample_rate = 48000 + self.upscale_ratio = 1920 + self.downscale_ratio = 1920 + + self.first_stage_model = AudioOobleckVAE(**config) self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * 2048) * model_management.dtype_size(dtype) self.latent_channels = 64 self.output_channels = 2 self.pad_channel_value = "replicate" - self.upscale_ratio = 2048 - self.downscale_ratio = 2048 self.latent_dim = 1 self.process_output = lambda audio: audio self.process_input = lambda audio: audio @@ -474,8 +621,8 @@ class VAE: self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config) self.latent_channels = 128 self.latent_dim = 3 - self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) @@ -549,8 +696,9 @@ class VAE: self.latent_dim = 3 self.latent_channels = 16 self.output_channels = sd["encoder.conv1.weight"].shape[1] + self.conv_out_channels = sd["decoder.head.2.weight"].shape[0] self.pad_channel_value = 1.0 - ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "image_channels": self.output_channels, "conv_out_channels": self.conv_out_channels, "dropout": 0.0} self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype) @@ -630,14 +778,13 @@ class VAE: self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) - if self.latent_channels == 48: # Wan 2.2 + if self.latent_channels in [48, 128]: # Wan 2.2 and 
LTX2 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_input = lambda image: image self.process_output = lambda image: image self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15) - self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical @@ -660,13 +807,6 @@ class VAE: self.first_stage_model = AutoencoderKL(**(config['params'])) self.first_stage_model = self.first_stage_model.eval() - m, u = self.first_stage_model.load_state_dict(sd, strict=False) - if len(m) > 0: - logging.warning("Missing VAE keys {}".format(m)) - - if len(u) > 0: - logging.debug("Leftover VAE keys {}".format(u)) - if device is None: device = model_management.vae_device() self.device = device @@ -675,9 +815,21 @@ class VAE: dtype = model_management.vae_dtype(self.device, self.working_dtypes) self.vae_dtype = dtype self.first_stage_model.to(self.vae_dtype) + model_management.archive_model_dtypes(self.first_stage_model) self.output_device = model_management.intermediate_device() - self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) + mp = comfy.model_patcher.CoreModelPatcher + if self.disable_offload: + mp = comfy.model_patcher.ModelPatcher + self.patcher = mp(self.first_stage_model, load_device=self.device, offload_device=offload_device) + + m, u = self.first_stage_model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) + if len(m) > 0: + logging.warning("Missing VAE keys {}".format(m)) + + if len(u) > 0: + logging.debug("Leftover VAE keys {}".format(u)) + logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) self.model_size() @@ -719,13 +871,16 @@ class VAE: pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) return pixels + def vae_output_dtype(self): + return model_management.intermediate_dtype() + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a:
self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -733,18 +888,18 @@ class VAE: / 3.0) return output - def decode_tiled_1d(self, samples, tile_x=128, overlap=32): + def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) else: og_shape = samples.shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) - decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -753,7 +908,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -762,7 +917,7 @@ class VAE: def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: - encode_fn = lambda a: 
self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -771,7 +926,7 @@ class VAE: tile_x = tile_x // extra_channel_size overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: @@ -780,7 +935,7 @@ class VAE: return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -792,17 +947,29 @@ class VAE: try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True + for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) - pixel_samples[x:x+batch_number] = out - except model_management.OOM_EXCEPTION: + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + 
pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. @@ -811,6 +978,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -837,7 +1005,7 @@ class VAE: if overlap is not None: args["overlap"] = overlap - if dims == 1: + if dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -866,18 +1034,24 @@ class VAE: try: memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) - free_memory = model_management.get_free_memory(self.device) + free_memory = self.patcher.get_free_memory(self.device) batch_number = int(free_memory / max(1, memory_used)) batch_number = max(1, batch_number) samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: - samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out - except model_management.OOM_EXCEPTION: + except Exception as e: + model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") #NOTE: We don't know what tensors were allocated to stack variables at the time of the #exception and the exception itself refs them all until we get out of this except block. 
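The decode()/encode() error paths above use a catch-and-filter idiom instead of catching the OOM exception type directly. A minimal sketch of that control flow, assuming (as the surrounding code suggests) that model_management.raise_non_oom(e) re-raises any exception that is not out-of-memory and returns normally for OOM ones; run_full_decode and run_tiled_decode are hypothetical stand-ins:
try:
    out = run_full_decode()            # full-batch attempt
except Exception as e:
    model_management.raise_non_oom(e)  # non-OOM errors propagate from here
    # only OOM falls through: release cached blocks, then retry tiled
    comfy.model_management.soft_empty_cache()
    out = run_tiled_decode()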
@@ -886,6 +1060,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 @@ -1008,16 +1183,26 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + NEWBIE = 24 + FLUX2 = 25 + LONGCAT_IMAGE = 26 -def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): + +def load_clip_model_patcher(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): + clip = load_clip(ckpt_paths, embedding_directory, clip_type, model_options, disable_dynamic) + return clip.patcher + +def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) if model_options.get("custom_operations", None) is None: sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) + clip = load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, disable_dynamic=disable_dynamic) + clip.patcher.cached_patcher_init = (load_clip_model_patcher, (ckpt_paths, embedding_directory, clip_type, model_options)) + return clip class TEModel(Enum): @@ -1038,6 +1223,11 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 + QWEN3_8B = 20 + QWEN3_06B = 21 + GEMMA_3_4B_VISION = 22 def detect_te_model(sd): @@ -1047,11 +1237,13 @@ def detect_te_model(sd): return TEModel.CLIP_H if "text_model.encoder.layers.0.mlp.fc1.weight" in sd: return TEModel.CLIP_L + if "model.encoder.layers.0.mixer.Wqkv.weight" in sd: + return TEModel.JINA_CLIP_2 if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd: weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] - if weight.shape[-1] == 4096: + if weight.shape[0] == 10240: return TEModel.T5_XXL - elif weight.shape[-1] == 2048: + elif weight.shape[0] == 5120: return TEModel.T5_XL if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd: return TEModel.T5_XXL_OLD @@ -1061,8 +1253,13 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: - return TEModel.GEMMA_3_4B + if 'vision_model.embeddings.patch_embedding.weight' in sd: + return TEModel.GEMMA_3_4B_VISION + else: + return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] @@ -1077,6 +1274,10 @@ def detect_te_model(sd): return TEModel.QWEN3_4B elif weight.shape[0] == 2048: return TEModel.QWEN3_2B + elif weight.shape[0] == 4096: + return TEModel.QWEN3_8B + elif weight.shape[0] == 1024: + return TEModel.QWEN3_06B if weight.shape[0] == 5120: if "model.layers.39.post_attention_layernorm.weight" in sd: return TEModel.MISTRAL3_24B @@ -1106,7 +1307,7 @@ def llama_detect(clip_data): return {} -def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): +def 
load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, disable_dynamic=False): clip_data = state_dicts class EmptyClass: @@ -1118,6 +1319,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: if "text_projection" in clip_data[i]: clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node + if "lm_head.weight" in clip_data[i]: + clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models tokenizer_data = {} clip_target = EmptyClass() @@ -1183,6 +1386,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_4B_VISION: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision") + clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif te_model == TEModel.GEMMA_3_12B: + clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) @@ -1194,6 +1405,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip if clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.LONGCAT_IMAGE: + clip_target.clip = comfy.text_encoders.longcat_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.longcat_image.LongCatImageTokenizer else: clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer @@ -1202,11 +1416,24 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) elif te_model == TEModel.QWEN3_4B: - clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer + if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer + else: + clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer elif te_model == TEModel.QWEN3_2B: clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer + elif te_model == TEModel.QWEN3_8B: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), 
model_type="qwen3_8b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B + elif te_model == TEModel.JINA_CLIP_2: + clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper + clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper + elif te_model == TEModel.QWEN3_06B: + clip_target.clip = comfy.text_encoders.anima.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.anima.AnimaTokenizer else: # clip_l if clip_type == CLIPType.SD3: @@ -1262,6 +1489,29 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) + elif clip_type == CLIPType.NEWBIE: + clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer + if "model.layers.0.self_attn.q_norm.weight" in clip_data[0]: + clip_data_gemma = clip_data[0] + clip_data_jina = clip_data[1] + else: + clip_data_gemma = clip_data[1] + clip_data_jina = clip_data[0] + tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None) + tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None) + elif clip_type == CLIPType.ACE: + te_models = [detect_te_model(clip_data[0]), detect_te_model(clip_data[1])] + if TEModel.QWEN3_4B in te_models: + model_type = "qwen3_4b" + else: + model_type = "qwen3_2b" + clip_target.clip = comfy.text_encoders.ace15.te(lm_model=model_type, **llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ace15.ACE15Tokenizer else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -1277,7 +1527,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options, disable_dynamic=disable_dynamic) return clip def load_gligen(ckpt_path): @@ -1285,7 +1535,7 @@ def load_gligen(ckpt_path): model = gligen.load_gligen(data) if model_management.should_use_fp16(): model = model.half() - return comfy.model_patcher.ModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) + return comfy.model_patcher.CoreModelPatcher(model, load_device=model_management.get_torch_device(), offload_device=model_management.unet_offload_device()) def model_detection_error_hint(path, state_dict): filename = os.path.basename(path) @@ -1317,14 +1567,34 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) -def 
load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): +def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata) + out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) if out is None: raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd))) + if output_model and out[0] is not None: + out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options)) + if output_clip and out[1] is not None: + out[1].patcher.cached_patcher_init = (load_checkpoint_guess_config_clip_only, (ckpt_path, embedding_directory, model_options, te_model_options)) return out -def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None): +def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False, + embedding_directory=embedding_directory, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return model + +def load_checkpoint_guess_config_clip_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False): + _, clip, *_ = load_checkpoint_guess_config(ckpt_path, False, True, False, + embedding_directory=embedding_directory, output_model=False, + model_options=model_options, + te_model_options=te_model_options, + disable_dynamic=disable_dynamic) + return clip.patcher + +def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False): clip = None clipvision = None vae = None @@ -1373,7 +1643,9 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c if output_model: inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype) model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device) - model.load_model_weights(sd, diffusion_model_prefix) + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) + model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic()) if output_vae: vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True) @@ -1407,7 +1679,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, 
output_c clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options, disable_dynamic=disable_dynamic) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") @@ -1416,7 +1688,6 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c logging.debug("left over keys: {}".format(left_over)) if output_model: - model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) if inital_load_device != torch.device("cpu"): logging.info("loaded diffusion model directly to GPU") model_management.load_models_gpu([model_patcher], force_full_load=True) @@ -1424,7 +1695,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (model_patcher, clip, vae, clipvision) -def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): +def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False): """ Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. @@ -1508,20 +1779,23 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - model = model.to(offload_device) - model.load_model_weights(new_sd, "") + ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher + model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) + if not model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic()) left_over = sd.keys() if len(left_over) > 0: logging.info("left over keys in diffusion model: {}".format(left_over)) - return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) + return model_patcher - -def load_diffusion_model(unet_path, model_options={}): +def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True) - model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata) + model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic) if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) + model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options)) return model def load_unet(unet_path, dtype=None): @@ -1545,9 +1819,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if metadata is None: metadata = {} - model_management.load_models_gpu(load_models, force_patch_weights=True) + model_management.load_models_gpu(load_models) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None - sd = 
model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) + sd = model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) for k in extra_keys: sd[k] = extra_keys[k] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 962948dae..0eb30df27 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -46,7 +46,7 @@ class ClipTokenWeightEncoder: out, pooled = o[:2] if pooled is not None: - first_pooled = pooled[0:1].to(model_management.intermediate_device()) + first_pooled = pooled[0:1].to(device=model_management.intermediate_device()) else: first_pooled = pooled @@ -63,16 +63,16 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled) else: - r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled) if len(o) > 2: extra = {} for k in o[2]: v = o[2][k] if k == "attention_mask": - v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) + v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device()) extra[k] = v r = r + (extra,) @@ -155,6 +155,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.execution_device = options.get("execution_device", self.execution_device) if isinstance(self.layer, list) or self.layer == "all": pass + elif isinstance(layer_idx, list): + self.layer = layer_idx elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" else: @@ -169,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) + pad_token = self.special_tokens.get("pad", -1) if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) + cmp_token = pad_token else: cmp_token = end_token @@ -184,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): other_embeds = [] eos = False index = 0 + left_pad = False for y in x: if isinstance(y, numbers.Integral): - if eos: + token = int(y) + if index == 0 and token == pad_token: + left_pad = True + + if eos or (left_pad and token == pad_token): attention_mask.append(0) else: attention_mask.append(1) - token = int(y) + left_pad = False + tokens_temp += [token] - if not eos and token == cmp_token: + if not eos and token == cmp_token and not left_pad: if end_token is None: attention_mask[-1] = 0 eos = True @@ -297,7 +306,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return self(tokens) def load_sd(self, sd): - return self.transformer.load_state_dict(sd, strict=False) + return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + if isinstance(tokens, dict): + tokens_only = next(iter(tokens.values())) # todo: get this better? 
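+            # tokens may arrive keyed by tokenizer name (the dict form produced by
+            # SD1Tokenizer.tokenize_with_weights); generation needs one flat stream
+            # of (token, weight) pairs, so take the first entry.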
+        else:
+            tokens_only = tokens
+        tokens_only = [[t[0] for t in b] for b in tokens_only]
+        embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
+        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
 
 def parse_parentheses(string):
     result = []
@@ -466,7 +484,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
         return embed_out
 
 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, start_token=None, min_padding=None, pad_left=False, disable_weights=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -479,8 +497,15 @@ class SDTokenizer:
         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
         if has_start_token:
-            self.tokens_start = 1
-            self.start_token = empty[0]
+            if len(empty) > 0:
+                self.tokens_start = 1
+                self.start_token = empty[0]
+            else:
+                self.tokens_start = 0
+                self.start_token = start_token
+                if start_token is None:
+                    logging.warning("WARNING: There's something wrong with your tokenizers.")
+
             if end_token is not None:
                 self.end_token = end_token
             else:
@@ -488,7 +513,7 @@ class SDTokenizer:
                 self.end_token = empty[1]
         else:
             self.tokens_start = 0
-            self.start_token = None
+            self.start_token = start_token
             if end_token is not None:
                 self.end_token = end_token
             else:
@@ -513,6 +538,8 @@ class SDTokenizer:
         self.embedding_size = embedding_size
         self.embedding_key = embedding_key
 
+        self.disable_weights = disable_weights
+
     def _try_get_embedding(self, embedding_name:str):
         '''
         Takes a potential embedding name and tries to retrieve it. 
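The left-padding support added to `process_tokens` above is the subtle part of this file: leading pad tokens (as emitted by left-padding tokenizers such as `Mistral3Tokenizer`) must be masked out without being mistaken for an end-of-sequence pad. Below is a minimal, standalone sketch of just the mask rules under illustrative token ids; the real method also interleaves embedding objects and tracks an output index.

```python
# Sketch of the attention-mask rules added to process_tokens above.
# Assumptions: integer tokens only; the pad id (0) is made up for the demo.
def build_attention_mask(tokens, pad_token, end_token=None):
    cmp_token = end_token if end_token is not None else pad_token
    mask, eos, left_pad = [], False, False
    for index, token in enumerate(tokens):
        if index == 0 and token == pad_token:
            left_pad = True   # sequence is left-padded
        if eos or (left_pad and token == pad_token):
            mask.append(0)    # leading pads and everything after EOS are masked
        else:
            mask.append(1)
            left_pad = False  # first real token ends the left-pad run
        if not eos and token == cmp_token and not left_pad:
            if end_token is None:
                mask[-1] = 0  # a pad acting as EOS is itself masked
            eos = True
    return mask

# Left padding: leading pads masked, content kept, no premature EOS.
assert build_attention_mask([0, 0, 5, 6], pad_token=0) == [0, 0, 1, 1]
# Right padding: the first trailing pad triggers EOS and is masked.
assert build_attention_mask([5, 6, 0, 0], pad_token=0) == [1, 1, 0, 0]
```

Without the `left_pad` guard, the very first pad of a left-padded prompt would trip the EOS branch and zero out the entire mask.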
@@ -546,8 +573,10 @@ class SDTokenizer: min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length) min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding) + min_length = kwargs.get("min_length", min_length) + text = escape_important(text) - if kwargs.get("disable_weights", False): + if kwargs.get("disable_weights", self.disable_weights): parsed_weights = [(text, 1.0)] else: parsed_weights = token_weights(text, 1.0) @@ -645,6 +674,9 @@ class SDTokenizer: def state_dict(self): return {} + def decode(self, token_ids, skip_special_tokens=True): + return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1Tokenizer: def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): if name is not None: @@ -668,6 +700,9 @@ class SD1Tokenizer: def state_dict(self): return getattr(self, self.clip).state_dict() + def decode(self, token_ids, skip_special_tokens=True): + return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens) + class SD1CheckpointClipModel(SDClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options) @@ -704,3 +739,6 @@ class SD1ClipModel(torch.nn.Module): def load_sd(self, sd): return getattr(self, self.clip).load_sd(sd) + + def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): + return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -23,6 +23,9 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image +import comfy.text_encoders.anima +import comfy.text_encoders.ace15 +import comfy.text_encoders.longcat_image from . import supported_models_base from . 
import latent_formats @@ -523,7 +526,8 @@ class LotusD(SD20): } unet_extra_config = { - "num_classes": 'sequential' + "num_classes": 'sequential', + "num_head_channels": 64, } def get_model(self, state_dict, prefix="", device=None): @@ -708,6 +712,15 @@ class Flux(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith("_norm.scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -763,17 +776,31 @@ class Flux2(Flux): def __init__(self, unet_config): super().__init__(unet_config) - self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * (unet_config['hidden_size'] / 2604) def get_model(self, state_dict, prefix="", device=None): out = model_base.Flux2(self, device=device) return out def clip_target(self, state_dict={}): - return None # TODO pref = self.text_encoder_key_prefix[0] - t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) - return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_4b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_8b.transformer.".format(pref)) + if len(detect) > 0: + detect["model_type"] = "qwen3_8b" + return supported_models_base.ClipTarget(comfy.text_encoders.flux.KleinTokenizer8B, comfy.text_encoders.flux.klein_te(**detect)) + + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}mistral3_24b.transformer.".format(pref)) + if len(detect) > 0: + if "{}mistral3_24b.transformer.model.layers.39.post_attention_layernorm.weight".format(pref) not in state_dict: + detect["pruned"] = True + return supported_models_base.ClipTarget(comfy.text_encoders.flux.Flux2Tokenizer, comfy.text_encoders.flux.flux2_te(**detect)) + + return None class GenmoMochi(supported_models_base.BASE): unet_config = { @@ -836,6 +863,21 @@ class LTXV(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.077 # TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -867,11 +909,13 @@ class HunyuanVideo(supported_models_base.BASE): key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.") key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", 
"_attn.qkv.") key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.") - key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale") - key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale") + key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight") + key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight") key_out = key_out.replace("_attn_proj.", "_attn.proj.") key_out = key_out.replace(".modulation.linear.", ".modulation.lin.") key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.") + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) out_sd[key_out] = state_dict[k] return out_sd @@ -962,7 +1006,7 @@ class CosmosT2IPredict2(supported_models_base.BASE): memory_usage_factor = 1.0 - supported_inference_dtypes = [torch.bfloat16, torch.float32] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) @@ -977,6 +1021,38 @@ class CosmosT2IPredict2(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect)) +class Anima(supported_models_base.BASE): + unet_config = { + "image_model": "anima", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + unet_extra_config = {} + latent_format = latent_formats.Wan21 + + memory_usage_factor = 1.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Anima(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_06b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.anima.AnimaTokenizer, comfy.text_encoders.anima.te(**detect)) + + def set_inference_dtype(self, dtype, manual_cast_dtype, **kwargs): + self.memory_usage_factor = (self.unet_config.get("model_channels", 2048) / 2048) * 0.95 + if dtype is torch.float16: + self.memory_usage_factor *= 1.4 + return super().set_inference_dtype(dtype, manual_cast_dtype, **kwargs) + class CosmosI2VPredict2(CosmosT2IPredict2): unet_config = { "image_model": "cosmos_predict2", @@ -1027,13 +1103,13 @@ class ZImage(Lumina2): "shift": 3.0, } - memory_usage_factor = 2.0 + memory_usage_factor = 2.8 supported_inference_dtypes = [torch.bfloat16, torch.float32] def __init__(self, unet_config): super().__init__(unet_config) - if comfy.model_management.extended_fp16_support(): + if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False): self.supported_inference_dtypes = self.supported_inference_dtypes.copy() self.supported_inference_dtypes.insert(1, torch.float16) @@ -1042,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, 
comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1182,6 +1272,26 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class WAN21_FlowRVS(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "flow_rvs", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_FlowRVS(self, image_to_video=True, device=device) + return out + +class WAN21_SCAIL(WAN21_T2V): + unet_config = { + "image_model": "wan2.1", + "model_type": "scail", + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) + return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1201,6 +1311,15 @@ class Hunyuan3Dv2(supported_models_base.BASE): latent_format = latent_formats.Hunyuan3Dv2 + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -1278,6 +1397,14 @@ class Chroma(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, device=device) @@ -1536,6 +1663,77 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +class ACEStep15(supported_models_base.BASE): + unet_config = { + "audio_model": "ace1.5", + } + + 
unet_extra_config = { + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + latent_format = comfy.latent_formats.ACEAudio15 + + memory_usage_factor = 4.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.ACEStep15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + detect_2b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_2b.transformer.".format(pref)) + detect_4b = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + if "dtype_llama" in detect_2b: + detect = detect_2b + detect["lm_model"] = "qwen3_2b" + elif "dtype_llama" in detect_4b: + detect = detect_4b + detect["lm_model"] = "qwen3_4b" + + return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) + + +class LongCatImage(supported_models_base.BASE): + unet_config = { + "image_model": "flux", + "guidance_embed": False, + "vec_in_dim": None, + "context_in_dim": 3584, + "txt_ids_dims": [1, 2], + } + + sampling_settings = { + } + + unet_extra_config = {} + latent_format = latent_formats.Flux + + memory_usage_factor = 2.5 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LongCatImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py index 3dfe1e4d4..6c06ce19d 100644 --- a/comfy/taesd/taehv.py +++ b/comfy/taesd/taehv.py @@ -112,7 +112,8 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar): class TAEHV(nn.Module): - def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + def __init__(self, latent_channels, parallel=False, encoder_time_downscale=(True, True, False), 
decoder_time_upscale=(False, True, True), decoder_space_upscale=(True, True, True), + latent_format=None, show_progress_bar=False): super().__init__() self.image_channels = 3 self.patch_size = 1 @@ -124,6 +125,9 @@ class TAEHV(nn.Module): self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 self.patch_size = 2 + elif self.latent_channels == 128: # LTX2 + self.patch_size, self.latent_channels, encoder_time_downscale, decoder_time_upscale = 4, 128, (True, True, True), (True, True, True) + if self.latent_channels == 32: # HunyuanVideo1.5 act_func = nn.LeakyReLU(0.2, inplace=True) else: # HunyuanVideo, Wan 2.1 @@ -131,41 +135,54 @@ class TAEHV(nn.Module): self.encoder = nn.Sequential( conv(self.image_channels*self.patch_size**2, 64), act_func, - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), - TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[0] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[1] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2 if encoder_time_downscale[2] else 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), conv(64, self.latent_channels), ) n_f = [256, 128, 64, 64] - self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( Clamp(), conv(self.latent_channels, n_f[0]), act_func, - MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), - MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), - MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1), conv(n_f[2], n_f[3], bias=False), act_func, conv(n_f[3], 
self.image_channels*self.patch_size**2), ) - @property - def show_progress_bar(self): - return self._show_progress_bar - @show_progress_bar.setter - def show_progress_bar(self, value): - self._show_progress_bar = value + self.t_downscale = 2**sum(t.stride == 2 for t in self.encoder if isinstance(t, TPool)) + self.t_upscale = 2**sum(t.stride == 2 for t in self.decoder if isinstance(t, TGrow)) + self.frames_to_trim = self.t_upscale - 1 + self._show_progress_bar = show_progress_bar + + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value def encode(self, x, **kwargs): - if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] - if x.shape[1] % 4 != 0: - # pad at end to multiple of 4 - n_pad = 4 - x.shape[1] % 4 + if self.patch_size > 1: + B, T, C, H, W = x.shape + x = x.reshape(B * T, C, H, W) + x = F.pixel_unshuffle(x, self.patch_size) + x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size) + if x.shape[1] % self.t_downscale != 0: + # pad at end to multiple of t_downscale + n_pad = self.t_downscale - x.shape[1] % self.t_downscale padding = x[:, -1:].repeat_interleave(n_pad, dim=1) x = torch.cat([x, padding], 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) return self.process_out(x) def decode(self, x, **kwargs): + x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] + x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) - if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + if self.patch_size > 1: + x = F.pixel_shuffle(x, self.patch_size) return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py new file mode 100644 index 000000000..853f021ae --- /dev/null +++ b/comfy/text_encoders/ace15.py @@ -0,0 +1,348 @@ +from .anima import Qwen3Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import torch +import math +import yaml +import comfy.utils + + +def sample_manual_loop_no_classes( + model, + ids=None, + execution_dtype=None, + cfg_scale: float = 2.0, + temperature: float = 0.85, + top_p: float = 0.9, + top_k: int = None, + min_p: float = 0.000, + seed: int = 1, + min_tokens: int = 1, + max_new_tokens: int = 2048, + audio_start_id: int = 151669, # The cutoff ID for audio codes + audio_end_id: int = 215669, + eos_token_id: int = 151645, +): + if ids is None: + return [] + device = model.execution_device + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + + embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) + embeds_batch = embeds.shape[0] + + output_audio_codes = [] + past_key_values = [] + generator = torch.Generator(device=device) + generator.manual_seed(seed) + model_config = model.transformer.model.config + past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim] + + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty(past_kv_shape, device=device, 
dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0)) + + progress_bar = comfy.utils.ProgressBar(max_new_tokens) + + for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"): + outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values) + next_token_logits = model.transformer.logits(outputs[0])[:, -1] + past_key_values = outputs[2] + + if cfg_scale != 1.0: + cond_logits = next_token_logits[0:1] + uncond_logits = next_token_logits[1:2] + cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits) + else: + cfg_logits = next_token_logits[0:1] + + use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step + if use_eos_score: + eos_score = cfg_logits[:, eos_token_id].clone() + + remove_logit_value = torch.finfo(cfg_logits.dtype).min + # Only generate audio tokens + cfg_logits[:, :audio_start_id] = remove_logit_value + cfg_logits[:, audio_end_id:] = remove_logit_value + + if use_eos_score: + cfg_logits[:, eos_token_id] = eos_score + + if top_k is not None and top_k > 0: + top_k_vals, _ = torch.topk(cfg_logits, top_k) + min_val = top_k_vals[..., -1, None] + cfg_logits[cfg_logits < min_val] = remove_logit_value + + if min_p is not None and min_p > 0: + probs = torch.softmax(cfg_logits, dim=-1) + p_max = probs.max(dim=-1, keepdim=True).values + indices_to_remove = probs < (min_p * p_max) + cfg_logits[indices_to_remove] = remove_logit_value + + if top_p is not None and top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + cfg_logits[indices_to_remove] = remove_logit_value + + if temperature > 0: + cfg_logits = cfg_logits / temperature + next_token = torch.multinomial(torch.softmax(cfg_logits, dim=-1), num_samples=1, generator=generator).squeeze(1) + else: + next_token = torch.argmax(cfg_logits, dim=-1) + + token = next_token.item() + + if token == eos_token_id: + break + + embed, _, _, _ = model.process_tokens([[token]], device) + embeds = embed.repeat(embeds_batch, 1, 1) + attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1) + + output_audio_codes.append(token - audio_start_id) + progress_bar.update_absolute(step) + + return output_audio_codes + + +def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000): + positive = [[token for token, _ in inner_list] for inner_list in positive] + positive = positive[0] + + if cfg_scale != 1.0: + negative = [[token for token, _ in inner_list] for inner_list in negative] + negative = negative[0] + + neg_pad = 0 + if len(negative) < len(positive): + neg_pad = (len(positive) - len(negative)) + negative = [model.special_tokens["pad"]] * neg_pad + negative + + pos_pad = 0 + if len(negative) > len(positive): + pos_pad = (len(negative) - len(positive)) + positive = [model.special_tokens["pad"]] * pos_pad + positive + + ids = [positive, negative] + else: + ids = 
[positive] + + return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + + +class ACE15Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_06b", tokenizer=Qwen3Tokenizer) + + def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str: + user_metas = { + k: kwargs.pop(k) + for k in ("bpm", "duration", "keyscale", "timesignature") + if k in kwargs + } + timesignature = user_metas.get("timesignature") + if isinstance(timesignature, str) and timesignature.endswith("/4"): + user_metas["timesignature"] = timesignature[:-2] + user_metas = { + k: v if not isinstance(v, str) or not v.isdigit() else int(v) + for k, v in user_metas.items() + if v not in {"unspecified", None} + } + if len(user_metas): + meta_yaml = yaml.dump(user_metas, allow_unicode=True, sort_keys=True).strip() + else: + meta_yaml = "" + return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml + + def _metas_to_cap(self, **kwargs) -> str: + use_keys = ("bpm", "timesignature", "keyscale", "duration") + user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys } + timesignature = user_metas.get("timesignature") + if isinstance(timesignature, str) and timesignature.endswith("/4"): + user_metas["timesignature"] = timesignature[:-2] + duration = user_metas["duration"] + if duration == "N/A": + user_metas["duration"] = "30 seconds" + elif isinstance(duration, (str, int, float)): + user_metas["duration"] = f"{math.ceil(float(duration))} seconds" + else: + raise TypeError("Unexpected type for duration key, must be str, int or float") + return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + text = text.strip() + text_negative = kwargs.get("caption_negative", text).strip() + lyrics = kwargs.get("lyrics", "") + lyrics_negative = kwargs.get("lyrics_negative", lyrics) + duration = kwargs.get("duration", 120) + if isinstance(duration, str): + duration = float(duration.split(None, 1)[0]) + language = kwargs.get("language") + seed = kwargs.get("seed", 0) + + generate_audio_codes = kwargs.get("generate_audio_codes", True) + cfg_scale = kwargs.get("cfg_scale", 2.0) + temperature = kwargs.get("temperature", 0.85) + top_p = kwargs.get("top_p", 0.9) + top_k = kwargs.get("top_k", 0.0) + min_p = kwargs.get("min_p", 0.000) + + duration = math.ceil(duration) + kwargs["duration"] = duration + tokens_duration = duration * 5 + min_tokens = int(kwargs.get("min_tokens", tokens_duration)) + max_tokens = int(kwargs.get("max_tokens", tokens_duration)) + + metas_negative = { + k.rsplit("_", 1)[0]: kwargs.pop(k) + for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative") + if k in kwargs + } + if not kwargs.get("use_negative_caption"): + _ = metas_negative.pop("caption", None) + + cot_text = self._metas_to_cot(caption=text, **kwargs) + cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative) + meta_cap = self._metas_to_cap(**kwargs) + + lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n" + 
lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>" + qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>" + + llm_prompts = { + "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text), + "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative), + "lyrics": lyrics_template.format(language if language is not None else "", lyrics), + "qwen3_06b": qwen3_06b_template.format(text, meta_cap), + } + + out = { + prompt_key: self.qwen3_06b.tokenize_with_weights( + prompt, + prompt_key == "qwen3_06b" and return_word_ids, + disable_weights = True, + **kwargs, + ) + for prompt_key, prompt in llm_prompts.items() + } + out["lm_metadata"] = {"min_tokens": min_tokens, + "max_tokens": max_tokens, + "seed": seed, + "generate_audio_codes": generate_audio_codes, + "cfg_scale": cfg_scale, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + } + return out + + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B_ACE15, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_2B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_2B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_4B_ACE15(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B_ACE15_lm, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class ACE15TEModel(torch.nn.Module): + def __init__(self, device="cpu", dtype=None, dtype_llama=None, lm_model=None, model_options={}): + super().__init__() + if dtype_llama is None: + dtype_llama = dtype + + model = None + self.constant = 0.4375 + if lm_model == "qwen3_4b": + model = Qwen3_4B_ACE15 + self.constant = 0.5625 + elif lm_model == "qwen3_2b": + model = Qwen3_2B_ACE15 + + self.lm_model = lm_model + self.qwen3_06b = Qwen3_06BModel(device=device, dtype=dtype, model_options=model_options) 
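+        # Only the 0.6B encoder above is always present; the optional Qwen3 LM
+        # (2B or 4B) is attached below under its own attribute name so that
+        # load_sd() can route incoming state dicts by hidden size. self.constant
+        # is a rough MiB-per-token budget used by memory_estimation_function().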
+        if model is not None:
+            setattr(self, self.lm_model, model(device=device, dtype=dtype_llama, model_options=model_options))
+
+        self.dtypes = set([dtype, dtype_llama])
+
+    def encode_token_weights(self, token_weight_pairs):
+        token_weight_pairs_base = token_weight_pairs["qwen3_06b"]
+        token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
+
+        self.qwen3_06b.set_clip_options({"layer": None})
+        base_out, _, extra = self.qwen3_06b.encode_token_weights(token_weight_pairs_base)
+        self.qwen3_06b.set_clip_options({"layer": [0]})
+        lyrics_embeds, _, extra_l = self.qwen3_06b.encode_token_weights(token_weight_pairs_lyrics)
+
+        out = {"conditioning_lyrics": lyrics_embeds[:, 0]}
+
+        lm_metadata = token_weight_pairs["lm_metadata"]
+        if lm_metadata["generate_audio_codes"]:
+            audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
+            out["audio_codes"] = [audio_codes]
+
+        return base_out, None, out
+
+    def set_clip_options(self, options):
+        self.qwen3_06b.set_clip_options(options)
+        lm_model = getattr(self, self.lm_model, None)
+        if lm_model is not None:
+            lm_model.set_clip_options(options)
+
+    def reset_clip_options(self):
+        self.qwen3_06b.reset_clip_options()
+        lm_model = getattr(self, self.lm_model, None)
+        if lm_model is not None:
+            lm_model.reset_clip_options()
+
+    def load_sd(self, sd):
+        if "model.layers.0.post_attention_layernorm.weight" in sd:
+            shape = sd["model.layers.0.post_attention_layernorm.weight"].shape
+            if shape[0] == 1024:
+                return self.qwen3_06b.load_sd(sd)
+            else:
+                return getattr(self, self.lm_model).load_sd(sd)
+
+    def memory_estimation_function(self, token_weight_pairs, device=None):
+        lm_metadata = token_weight_pairs.get("lm_metadata", {})
+        constant = self.constant
+        if comfy.model_management.should_use_bf16(device):
+            constant *= 0.5
+
+        token_weight_pairs = token_weight_pairs.get("lm_prompt", [])
+        num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+        num_tokens += lm_metadata.get("min_tokens", 0)
+        return num_tokens * constant * 1024 * 1024
+
+def te(dtype_llama=None, llama_quantization_metadata=None, lm_model="qwen3_2b"):
+    class ACE15TEModel_(ACE15TEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype_llama=dtype_llama, lm_model=lm_model, dtype=dtype, model_options=model_options)
+    return ACE15TEModel_
diff --git a/comfy/text_encoders/anima.py b/comfy/text_encoders/anima.py
new file mode 100644
index 000000000..2e31b2b04
--- /dev/null
+++ b/comfy/text_encoders/anima.py
@@ -0,0 +1,63 @@
+from transformers import Qwen2Tokenizer, T5TokenizerFast
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='qwen3_06b', 
tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + +class AnimaTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.qwen3_06b = Qwen3Tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + qwen_ids = self.qwen3_06b.tokenize_with_weights(text, return_word_ids, **kwargs) + out["qwen3_06b"] = [[(k[0], 1.0, k[2]) if return_word_ids else (k[0], 1.0) for k in inner_list] for inner_list in qwen_ids] # Set weights to 1.0 + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + return self.t5xxl.untokenize(token_weight_pair) + + def state_dict(self): + return {} + + def decode(self, token_ids, **kwargs): + return self.qwen3_06b.decode(token_ids, **kwargs) + +class Qwen3_06BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_06B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class AnimaTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_06b", clip_model=Qwen3_06BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out = super().encode_token_weights(token_weight_pairs) + out[2]["t5xxl_ids"] = torch.tensor(list(map(lambda a: a[0], token_weight_pairs["t5xxl"][0])), dtype=torch.int) + out[2]["t5xxl_weights"] = torch.tensor(list(map(lambda a: a[1], token_weight_pairs["t5xxl"][0]))) + return out + +def te(dtype_llama=None, llama_quantization_metadata=None): + class AnimaTEModel_(AnimaTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return AnimaTEModel_ diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 448381fa9..f4b40ac68 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = 
t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return CosmosTEModel_ diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 21d93d757..1ae398789 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -3,7 +3,7 @@ import comfy.text_encoders.t5 import comfy.text_encoders.sd3_clip import comfy.text_encoders.llama import comfy.model_management -from transformers import T5TokenizerFast, LlamaTokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast, Qwen2Tokenizer import torch import os import json @@ -118,7 +118,7 @@ class MistralTokenizerClass: class Mistral3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): self.tekken_data = tokenizer_data.get("tekken_model", None) - super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + super().__init__("", pad_with_end=False, embedding_directory=embedding_directory, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, start_token=1, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) def state_dict(self): return {"tekken_model": self.tekken_data} @@ -172,3 +172,60 @@ def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): model_options["num_layers"] = 30 super().__init__(device=device, dtype=dtype, model_options=model_options) return Flux2TEModel_ + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class Qwen3Tokenizer8B(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=4096, embedding_key='qwen3_8b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=151643, tokenizer_data=tokenizer_data) + +class KleinTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_4b"): + if name == "qwen3_4b": + tokenizer = Qwen3Tokenizer + elif name == "qwen3_8b": + tokenizer = Qwen3Tokenizer8B + + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name, tokenizer=tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if
llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class KleinTokenizer8B(KleinTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}, name="qwen3_8b"): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name=name) + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +class Qwen3_8BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=[9, 18, 27], layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_8B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + +def klein_te(dtype_llama=None, llama_quantization_metadata=None, model_type="qwen3_4b"): + if model_type == "qwen3_4b": + model = Qwen3_4BModel + elif model_type == "qwen3_8b": + model = Qwen3_8BModel + + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) + return Flux2TEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 5daea8135..2d7a3fbce 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return MochiTEModel_ diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index a9a6c525e..2ddb4da60 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -10,9 +10,11 @@ import comfy.utils def llama_detect(state_dict, prefix=""): out = {} - t5_key = "{}model.norm.weight".format(prefix) - if t5_key in state_dict: - out["dtype_llama"] = state_dict[t5_key].dtype + norm_keys = ["{}model.norm.weight".format(prefix), "{}model.layers.0.input_layernorm.weight".format(prefix)] + for norm_key in norm_keys: + if norm_key in state_dict: + out["dtype_llama"] = state_dict[norm_key].dtype + break quant = comfy.utils.detect_layer_quantization(state_dict, prefix) if quant is not None: diff --git a/comfy/text_encoders/jina_clip_2.py b/comfy/text_encoders/jina_clip_2.py new file mode 100644 index 
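# --- Editor's sketch (illustrative, not part of the diff): llama_detect now
# probes several candidate norm keys instead of only model.norm.weight, so
# checkpoints that lack a final norm still report a dtype. A standalone
# analogue of that probe:
def detect_llama_dtype(state_dict, prefix=""):
    candidates = ["model.norm.weight", "model.layers.0.input_layernorm.weight"]
    for name in candidates:
        key = prefix + name
        if key in state_dict:
            return state_dict[key].dtype
    return None
# --- end sketch ---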
000000000..0cffb6d16 --- /dev/null +++ b/comfy/text_encoders/jina_clip_2.py @@ -0,0 +1,220 @@ +# Jina CLIP v2 and Jina Embeddings v3 both use Jina's modified XLM-RoBERTa architecture. Reference implementations: +# Jina CLIP v2 (both text and vision): https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/modeling_clip.py +# Jina XLM-RoBERTa (text only): http://huggingface.co/jinaai/xlm-roberta-flash-implementation/blob/2b6bc3f30750b3a9648fe9b63448c09920efe9be/modeling_xlm_roberta.py + +from dataclasses import dataclass + +import torch +from torch import nn as nn +from torch.nn import functional as F + +import comfy.model_management +import comfy.ops +import comfy.ldm.modules.attention +from comfy import sd1_clip +from .spiece_tokenizer import SPieceTokenizer + +class JinaClip2Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + # The official NewBie implementation uses max_length=8000, but Jina Embeddings v3 actually supports 8192 + super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='jina_clip_2', tokenizer_class=SPieceTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=False, max_length=8192, min_length=1, pad_token=1, end_token=2, tokenizer_args={"add_bos": True, "add_eos": True}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class JinaClip2TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=JinaClip2Tokenizer, name="jina_clip_2") + +# https://huggingface.co/jinaai/jina-embeddings-v3/blob/343dbf534c76fe845f304fa5c2d1fd87e1e78918/config.json +@dataclass +class XLMRobertaConfig: + vocab_size: int = 250002 + type_vocab_size: int = 1 + hidden_size: int = 1024 + num_hidden_layers: int = 24 + num_attention_heads: int = 16 + rotary_emb_base: float = 20000.0 + intermediate_size: int = 4096 + hidden_act: str = "gelu" + hidden_dropout_prob: float = 0.1 + attention_probs_dropout_prob: float = 0.1 + layer_norm_eps: float = 1e-05 + bos_token_id: int = 0 + eos_token_id: int = 2 + pad_token_id: int = 1 + +class XLMRobertaEmbeddings(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.word_embeddings = ops.Embedding(config.vocab_size, embed_dim, padding_idx=config.pad_token_id, device=device, dtype=dtype) + self.token_type_embeddings = ops.Embedding(config.type_vocab_size, embed_dim, device=device, dtype=dtype) + + def forward(self, input_ids=None, embeddings=None): + if input_ids is not None and embeddings is None: + embeddings = self.word_embeddings(input_ids) + + if embeddings is not None: + token_type_ids = torch.zeros(embeddings.shape[1], device=embeddings.device, dtype=torch.int32) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = embeddings + token_type_embeddings + return embeddings + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + if seqlen >
self._seq_len_cached or self._cos_cached is None or self._cos_cached.device != device or self._cos_cached.dtype != dtype: + self._seq_len_cached = seqlen + t = torch.arange(seqlen, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq.to(device=t.device)) + emb = torch.cat((freqs, freqs), dim=-1) + self._cos_cached = emb.cos().to(dtype) + self._sin_cached = emb.sin().to(dtype) + + def forward(self, q, k): + batch, seqlen, heads, head_dim = q.shape + self._update_cos_sin_cache(seqlen, device=q.device, dtype=q.dtype) + + cos = self._cos_cached[:seqlen].view(1, seqlen, 1, head_dim) + sin = self._sin_cached[:seqlen].view(1, seqlen, 1, head_dim) + + def rotate_half(x): + size = x.shape[-1] // 2 + x1, x2 = x[..., :size], x[..., size:] + return torch.cat((-x2, x1), dim=-1) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +class MHA(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = embed_dim // config.num_attention_heads + + self.rotary_emb = RotaryEmbedding(self.head_dim, config.rotary_emb_base, device=device) + self.Wqkv = ops.Linear(embed_dim, 3 * embed_dim, device=device, dtype=dtype) + self.out_proj = ops.Linear(embed_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, x, mask=None, optimized_attention=None): + qkv = self.Wqkv(x) + batch_size, seq_len, _ = qkv.shape + qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + + q, k = self.rotary_emb(q, k) + + # NHD -> HND + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + out = optimized_attention(q, k, v, heads=self.num_heads, mask=mask, skip_reshape=True) + return self.out_proj(out) + +class MLP(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.fc1 = ops.Linear(config.hidden_size, config.intermediate_size, device=device, dtype=dtype) + self.activation = F.gelu + self.fc2 = ops.Linear(config.intermediate_size, config.hidden_size, device=device, dtype=dtype) + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + +class Block(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.mixer = MHA(config, device=device, dtype=dtype, ops=ops) + self.dropout1 = nn.Dropout(config.hidden_dropout_prob) + self.norm1 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + self.dropout2 = nn.Dropout(config.hidden_dropout_prob) + self.norm2 = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + + def forward(self, hidden_states, mask=None, optimized_attention=None): + mixer_out = self.mixer(hidden_states, mask=mask, optimized_attention=optimized_attention) + hidden_states = self.norm1(self.dropout1(mixer_out) + hidden_states) + mlp_out = self.mlp(hidden_states) + hidden_states = self.norm2(self.dropout2(mlp_out) + hidden_states) + return hidden_states + +class XLMRobertaEncoder(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.layers = nn.ModuleList([Block(config, device=device, dtype=dtype, ops=ops) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask=None): + 
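# --- Editor's sketch (illustrative, not part of the diff): the rotary cache
# above recomputes cos/sin only when the requested sequence length grows past
# the cached one or the device/dtype changes; the same invalidation rule,
# standalone.
import torch

class CosSinCache:
    def __init__(self, inv_freq):
        self.inv_freq = inv_freq  # (head_dim // 2,) inverse frequencies
        self.cached_len, self.cos, self.sin = 0, None, None

    def get(self, seqlen, device, dtype):
        stale = (seqlen > self.cached_len or self.cos is None
                 or self.cos.device != device or self.cos.dtype != dtype)
        if stale:
            t = torch.arange(seqlen, device=device, dtype=torch.float32)
            freqs = torch.outer(t, self.inv_freq.to(device))
            emb = torch.cat((freqs, freqs), dim=-1)  # duplicated halves, as above
            self.cached_len, self.cos, self.sin = seqlen, emb.cos().to(dtype), emb.sin().to(dtype)
        return self.cos[:seqlen], self.sin[:seqlen]
# --- end sketch ---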
optimized_attention = comfy.ldm.modules.attention.optimized_attention_for_device(hidden_states.device, mask=attention_mask is not None, small_input=True) + for layer in self.layers: + hidden_states = layer(hidden_states, mask=attention_mask, optimized_attention=optimized_attention) + return hidden_states + +class XLMRobertaModel_(nn.Module): + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__() + self.embeddings = XLMRobertaEmbeddings(config, device=device, dtype=dtype, ops=ops) + self.emb_ln = ops.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, device=device, dtype=dtype) + self.emb_drop = nn.Dropout(config.hidden_dropout_prob) + self.encoder = XLMRobertaEncoder(config, device=device, dtype=dtype, ops=ops) + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, embeds_info=[]): + x = self.embeddings(input_ids=input_ids, embeddings=embeds) + x = self.emb_ln(x) + x = self.emb_drop(x) + + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, 1, attention_mask.shape[-1])) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + + sequence_output = self.encoder(x, attention_mask=mask) + + # Mean pool, see https://huggingface.co/jinaai/jina-clip-implementation/blob/39e6a55ae971b59bea6e44675d237c99762e7ee2/hf_model.py + pooled_output = None + if attention_mask is None: + pooled_output = sequence_output.mean(dim=1) + else: + attention_mask = attention_mask.to(sequence_output.dtype) + pooled_output = (sequence_output * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=-1, keepdim=True) + + # Intermediate output is not yet implemented, use None for placeholder + return sequence_output, None, pooled_output + +class XLMRobertaModel(nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + self.config = XLMRobertaConfig(**config_dict) + self.model = XLMRobertaModel_(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.model.embeddings.word_embeddings + + def set_input_embeddings(self, embeddings): + self.model.embeddings.word_embeddings = embeddings + + def forward(self, *args, **kwargs): + return self.model(*args, **kwargs) + +class JinaClip2TextModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, textmodel_json_config={}, model_class=XLMRobertaModel, special_tokens={"start": 0, "end": 2, "pad": 1}, enable_attention_masks=True, return_attention_masks=True, model_options=model_options) + +class JinaClip2TextModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, clip_model=JinaClip2TextModel, name="jina_clip_2", model_options=model_options) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index 0d07ac8c6..9fdea999c 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -1,15 +1,17 @@ import torch import torch.nn as nn from dataclasses import dataclass -from typing import Optional, Any +from typing import Optional, Any, Tuple import math -import logging +from tqdm import tqdm +import comfy.utils from comfy.ldm.modules.attention import optimized_attention_for_device import 
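# --- Editor's sketch (illustrative, not part of the diff): the mask-weighted
# mean pool used above to produce the Jina text embedding, standalone.
import torch

def masked_mean_pool(hidden, attention_mask=None):
    # hidden: (batch, seq, dim); attention_mask: (batch, seq), 1 = real token
    if attention_mask is None:
        return hidden.mean(dim=1)
    m = attention_mask.to(hidden.dtype)
    return (hidden * m.unsqueeze(-1)).sum(dim=1) / m.sum(dim=-1, keepdim=True)
# --- end sketch ---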
comfy.model_management +import comfy.ops import comfy.ldm.common_dit +import comfy.clip_model -import comfy.model_management from . import qwen_vl @dataclass @@ -33,6 +35,7 @@ class Llama2Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Mistral3Small24BConfig: @@ -55,6 +58,7 @@ class Mistral3Small24BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_3BConfig: @@ -77,6 +81,103 @@ class Qwen25_3BConfig: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False + +@dataclass +class Qwen3_06BConfig: + vocab_size: int = 151936 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] + +@dataclass +class Qwen3_06B_ACE15_Config: + vocab_size: int = 151669 + hidden_size: int = 1024 + intermediate_size: int = 3072 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 32768 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] + +@dataclass +class Qwen3_2B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] + +@dataclass +class Qwen3_4B_ACE15_lm_Config: + vocab_size: int = 217204 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Qwen3_4BConfig: @@ -99,6 +200,32 @@ class Qwen3_4BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] + +@dataclass +class Qwen3_8BConfig: + vocab_size: int = 151936 + hidden_size: int = 4096 + intermediate_size: int = 12288 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = 
"llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + lm_head: bool = False + stop_tokens = [151643, 151645] @dataclass class Ovis25_2BConfig: @@ -121,6 +248,7 @@ class Ovis25_2BConfig: k_norm = "gemma3" rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Qwen25_7BVLI_Config: @@ -143,6 +271,7 @@ class Qwen25_7BVLI_Config: k_norm = None rope_scale = None final_norm: bool = True + lm_head: bool = False @dataclass class Gemma2_2B_Config: @@ -166,6 +295,8 @@ class Gemma2_2B_Config: sliding_attention = None rope_scale = None final_norm: bool = True + lm_head: bool = False + stop_tokens = [1] @dataclass class Gemma3_4B_Config: @@ -177,7 +308,7 @@ class Gemma3_4B_Config: num_key_value_heads: int = 4 max_position_embeddings: int = 131072 rms_norm_eps: float = 1e-6 - rope_theta = [10000.0, 1000000.0] + rope_theta = [1000000.0, 10000.0] transformer_type: str = "gemma3" head_dim = 256 rms_norm_add = True @@ -186,9 +317,45 @@ class Gemma3_4B_Config: rope_dims = None q_norm = "gemma3" k_norm = "gemma3" - sliding_attention = [False, False, False, False, False, 1024] - rope_scale = [1.0, 8.0] + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] final_norm: bool = True + lm_head: bool = False + stop_tokens = [1, 106] + +GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + +@dataclass +class Gemma3_4B_Vision_Config(Gemma3_4B_Config): + vision_config = GEMMA3_VISION_CONFIG + mm_tokens_per_image = 256 + +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + lm_head: bool = False + vision_config = GEMMA3_VISION_CONFIG + mm_tokens_per_image = 256 + stop_tokens = [1, 106] class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): @@ -206,13 +373,6 @@ class RMSNorm(nn.Module): -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): if not isinstance(theta, list): theta = [theta] @@ -241,20 +401,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di else: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) - out.append((cos, sin)) + sin_split = sin.shape[-1] // 2 + out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) if len(out) == 1: return out[0] return out - def apply_rope(xq, xk, freqs_cis): org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] - q_embed = (xq * cos) + (rotate_half(xq) * sin) - k_embed = (xk * cos) + 
(rotate_half(xk) * sin) + nsin = freqs_cis[2] + + q_embed = (xq * cos) + q_split = q_embed.shape[-1] // 2 + q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin) + q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin) + + k_embed = (xk * cos) + k_split = k_embed.shape[-1] // 2 + k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin) + k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin) + return q_embed.to(org_dtype), k_embed.to(org_dtype) @@ -288,8 +458,11 @@ class Attention(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + sliding_window: Optional[int] = None, ): batch_size, seq_length, _ = hidden_states.shape + xq = self.q_proj(hidden_states) xk = self.k_proj(hidden_states) xv = self.v_proj(hidden_states) @@ -305,11 +478,35 @@ class Attention(nn.Module): xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis) + present_key_value = None + if past_key_value is not None: + index = 0 + num_tokens = xk.shape[2] + if len(past_key_value) > 0: + past_key, past_value, index = past_key_value + if past_key.shape[2] >= (index + num_tokens): + past_key[:, :, index:index + xk.shape[2]] = xk + past_value[:, :, index:index + xv.shape[2]] = xv + xk = past_key[:, :, :index + xk.shape[2]] + xv = past_value[:, :, :index + xv.shape[2]] + present_key_value = (past_key, past_value, index + num_tokens) + else: + xk = torch.cat((past_key[:, :, :index], xk), dim=2) + xv = torch.cat((past_value[:, :, :index], xv), dim=2) + present_key_value = (xk, xv, index + num_tokens) + else: + present_key_value = (xk, xv, index + num_tokens) + + if sliding_window is not None and xk.shape[2] > sliding_window: + xk = xk[:, :, -sliding_window:] + xv = xv[:, :, -sliding_window:] + attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True) - return self.o_proj(output) + return self.o_proj(output), present_key_value class MLP(nn.Module): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): @@ -340,15 +537,17 @@ class TransformerBlock(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, ) x = residual + x @@ -358,7 +557,7 @@ class TransformerBlock(nn.Module): x = self.mlp(x) x = residual + x - return x + return x, present_key_value class TransformerBlockGemma2(nn.Module): def __init__(self, config: Llama2Config, index, device=None, dtype=None, ops: Any = None): @@ -370,7 +569,7 @@ class TransformerBlockGemma2(nn.Module): self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) - if config.sliding_attention is not None: # TODO: implement. 
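# --- Editor's sketch (illustrative, not part of the diff): the rewritten
# apply_rope builds the rotation in place with addcmul_ instead of allocating
# rotate_half intermediates; a quick self-check that both forms agree when the
# sin table has the duplicated-half layout produced by precompute_freqs_cis.
import torch

def rope_reference(x, cos, sin):
    h = x.shape[-1] // 2
    rotated = torch.cat((-x[..., h:], x[..., :h]), dim=-1)
    return x * cos + rotated * sin

def rope_inplace(x, cos, sin):
    h = x.shape[-1] // 2
    out = x * cos
    out[..., :h].addcmul_(x[..., h:], -sin[..., h:])
    out[..., h:].addcmul_(x[..., :h], sin[..., :h])
    return out

x = torch.randn(2, 4, 8)
angles = torch.rand(4, 4)
cos = torch.cat((angles, angles), dim=-1).cos()
sin = torch.cat((angles, angles), dim=-1).sin()
assert torch.allclose(rope_reference(x, cos, sin), rope_inplace(x, cos, sin), atol=1e-6)
# --- end sketch ---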
(Not that necessary since models are trained on less than 1024 tokens) + if config.sliding_attention is not None: self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)] else: self.sliding_attention = False @@ -383,11 +582,19 @@ class TransformerBlockGemma2(nn.Module): attention_mask: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None, optimized_attention=None, + past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, ): + sliding_window = None if self.transformer_type == 'gemma3': if self.sliding_attention: + sliding_window = self.sliding_attention if x.shape[1] > self.sliding_attention: - logging.warning("Warning: sliding attention not implemented, results may be incorrect") + sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype) + sliding_mask.tril_(diagonal=-self.sliding_attention) + if attention_mask is not None: + attention_mask = attention_mask + sliding_mask + else: + attention_mask = sliding_mask freqs_cis = freqs_cis[1] else: freqs_cis = freqs_cis[0] @@ -395,11 +602,13 @@ class TransformerBlockGemma2(nn.Module): # Self Attention residual = x x = self.input_layernorm(x) - x = self.self_attn( + x, present_key_value = self.self_attn( hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_key_value, + sliding_window=sliding_window, ) x = self.post_attention_layernorm(x) @@ -412,7 +621,7 @@ class TransformerBlockGemma2(nn.Module): x = self.post_feedforward_layernorm(x) x = residual + x - return x + return x, present_key_value class Llama2_(nn.Module): def __init__(self, config, device=None, dtype=None, ops=None): @@ -443,9 +652,10 @@ class Llama2_(nn.Module): else: self.norm = None - # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + if config.lm_head: + self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) - def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): + def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): if embeds is not None: x = embeds else: @@ -454,8 +664,13 @@ class Llama2_(nn.Module): if self.normalize_in: x *= self.config.hidden_size ** 0.5 + seq_len = x.shape[1] + past_len = 0 + if past_key_values is not None and len(past_key_values) > 0: + past_len = past_key_values[0][2] + if position_ids is None: - position_ids = torch.arange(0, x.shape[1], device=x.device).unsqueeze(0) + position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0) freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, @@ -466,14 +681,16 @@ class Llama2_(nn.Module): mask = None if attention_mask is not None: - mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) - mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) + mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = 
mask.masked_fill(mask.to(torch.bool), torch.finfo(x.dtype).min / 4) + + if seq_len > 1: + causal_mask = torch.empty(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device).fill_(torch.finfo(x.dtype).min / 4).triu_(1) + if mask is not None: + mask += causal_mask + else: + mask = causal_mask - causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1) - if mask is not None: - mask += causal_mask - else: - mask = causal_mask optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True) intermediate = None @@ -489,16 +706,27 @@ class Llama2_(nn.Module): elif intermediate_output < 0: intermediate_output = len(self.layers) + intermediate_output + next_key_values = [] for i, layer in enumerate(self.layers): if all_intermediate is not None: if only_layers is None or (i in only_layers): all_intermediate.append(x.unsqueeze(1).clone()) - x = layer( + + past_kv = None + if past_key_values is not None: + past_kv = past_key_values[i] if len(past_key_values) > 0 else [] + + x, current_kv = layer( x=x, attention_mask=mask, freqs_cis=freqs_cis, optimized_attention=optimized_attention, + past_key_value=past_kv, ) + + if current_kv is not None: + next_key_values.append(current_kv) + if i == intermediate_output: intermediate = x.clone() @@ -515,7 +743,45 @@ class Llama2_(nn.Module): if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) - return x, intermediate + if len(next_key_values) > 0: + return x, intermediate, next_key_values + else: + return x, intermediate + + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + class BaseLlama: def get_input_embeddings(self): @@ -527,6 +793,122 @@ class BaseLlama: def forward(self, input_ids, *args, **kwargs): return self.model(input_ids, *args, **kwargs) +class BaseGenerate: + def logits(self, x): + input = 
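# --- Editor's sketch (illustrative, not part of the diff): the preallocated
# KV-cache pattern threaded through the layers above -- new keys/values are
# copied into a fixed buffer at a running index and a prefix view is attended
# over, avoiding a concat per decode step.
import torch

batch, kv_heads, max_len, head_dim = 1, 8, 32, 128
buf_k = torch.empty(batch, kv_heads, max_len, head_dim)
index = 0
new_k = torch.randn(batch, kv_heads, 7, head_dim)  # prefill chunk
buf_k[:, :, index:index + new_k.shape[2]] = new_k
k_view = buf_k[:, :, :index + new_k.shape[2]]      # attend over valid prefix only
index += new_k.shape[2]
# --- end sketch ---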
x[:, -1:] + if hasattr(self.model, "lm_head"): + module = self.model.lm_head + else: + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = module.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x + + def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0): + device = embeds.device + model_config = self.model.config + + if stop_tokens is None: + stop_tokens = self.model.config.stop_tokens + + if execution_dtype is None: + if comfy.model_management.should_use_bf16(device): + execution_dtype = torch.bfloat16 + else: + execution_dtype = torch.float32 + embeds = embeds.to(execution_dtype) + + if embeds.ndim == 2: + embeds = embeds.unsqueeze(0) + + past_key_values = [] # kv_cache init + max_cache_len = embeds.shape[1] + max_length + for x in range(model_config.num_hidden_layers): + past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), + torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) + + generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None + + generated_token_ids = [] + pbar = comfy.utils.ProgressBar(max_length) + + # Generation loop + for step in tqdm(range(max_length), desc="Generating tokens"): + x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) + logits = self.logits(x)[:, -1] + next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample) + token_id = next_token[0].item() + generated_token_ids.append(token_id) + + embeds = self.model.embed_tokens(next_token).to(execution_dtype) + pbar.update(1) + + if token_id in stop_tokens and len(generated_token_ids) >= min_tokens: + break + + return generated_token_ids + + def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True): + + if not do_sample or temperature == 0.0: + return torch.argmax(logits, dim=-1, keepdim=True) + + # Sampling mode + if repetition_penalty != 1.0: + for i in range(logits.shape[0]): + for token_id in set(token_history): + logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty + + if temperature != 1.0: + logits = logits / temperature + + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if min_p > 0.0: + probs_before_filter = torch.nn.functional.softmax(logits, dim=-1) + top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True) + min_threshold = min_p * top_probs + indices_to_remove = probs_before_filter < min_threshold + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p +
sorted_indices_to_remove[..., 0] = False + indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) + indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = torch.finfo(logits.dtype).min + + probs = torch.nn.functional.softmax(logits, dim=-1) + + return torch.multinomial(probs, num_samples=1, generator=generator) + +class BaseQwen3: + def logits(self, x): + input = x[:, -1:] + module = self.model.embed_tokens + + offload_stream = None + if module.comfy_cast_weights: + weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True) + else: + weight = self.model.embed_tokens.weight.to(x) + + x = torch.nn.functional.linear(input, weight, None) + + comfy.ops.uncast_bias_weight(module, weight, None, offload_stream) + return x class Llama2(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): @@ -555,7 +937,34 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen3_4B(BaseLlama, torch.nn.Module): +class Qwen3_06B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_06B_ACE15(BaseLlama, BaseQwen3, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_06B_ACE15_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_2B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_2B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_4B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen3_4BConfig(**config_dict) @@ -564,6 +973,24 @@ class Qwen3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype +class Qwen3_4B_ACE15_lm(BaseLlama, BaseQwen3, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4B_ACE15_lm_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + +class Qwen3_8B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_8BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + class Ovis25_2B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -573,7 +1000,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Qwen25_7BVLI(BaseLlama, torch.nn.Module): +class Qwen25_7BVLI(BaseLlama, BaseGenerate, 
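# --- Editor's sketch (illustrative, not part of the diff): the nucleus
# (top-p) step in sample_token above, standalone -- it removes every token
# outside the smallest probability-sorted prefix whose cumulative mass
# exceeds top_p, always keeping the most likely token.
import torch

def top_p_filter(logits, top_p):
    # logits: (batch, vocab)
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cum > top_p
    remove[..., 0] = False  # always keep the top token
    mask = torch.zeros_like(logits, dtype=torch.bool).scatter_(1, sorted_idx, remove)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
# --- end sketch ---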
torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Qwen25_7BVLI_Config(**config_dict) @@ -583,6 +1010,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.dtype = dtype + # todo: should this be tied or not? + #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) + def preprocess_embed(self, embed, device): if embed["type"] == "image": image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) @@ -598,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): grid = e.get("extra", None) start = e.get("index") if position_ids is None: - position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) + position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long) position_ids[:, :start] = torch.arange(0, start, device=embeds.device) end = e.get("size") + start len_max = int(grid.max()) // 2 start_next = len_max + start - position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) + if attention_mask is not None: + # Assign compact sequential positions to attended tokens only, + # skipping over padding so post-padding tokens aren't inflated. + after_mask = attention_mask[0, end:] + text_positions = after_mask.cumsum(0) - 1 + start_next + offset + position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:]) + else: + position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) position_ids[0, start:end] = start + offset max_d = int(grid[0][1]) // 2 position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] @@ -616,7 +1053,7 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module): return super().forward(x, attention_mask=attention_mask, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=final_layer_norm_intermediate, dtype=dtype, position_ids=position_ids) -class Gemma2_2B(BaseLlama, torch.nn.Module): +class Gemma2_2B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma2_2B_Config(**config_dict) @@ -625,7 +1062,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype -class Gemma3_4B(BaseLlama, torch.nn.Module): +class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() config = Gemma3_4B_Config(**config_dict) @@ -633,3 +1070,39 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_4B_Vision_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model 
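# --- Editor's sketch (illustrative, not part of the diff): the padding-aware
# branch above assigns compact sequential positions to attended tokens only,
# via a cumulative sum over the attention mask.
import torch

attn = torch.tensor([1, 1, 0, 0, 1, 1])   # 1 = real token, 0 = padding
pos = attn.cumsum(0) - 1                   # -> [0, 1, 1, 1, 2, 3]
pos = torch.where(attn.bool(), pos, torch.zeros_like(pos))
# real tokens receive 0, 1, 2, 3; padded slots keep a filler position
# --- end sketch ---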
= comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None + +class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/longcat_image.py b/comfy/text_encoders/longcat_image.py new file mode 100644 index 000000000..0962779e3 --- /dev/null +++ b/comfy/text_encoders/longcat_image.py @@ -0,0 +1,199 @@ +import re +import numbers +import torch +from comfy import sd1_clip +from comfy.text_encoders.qwen_image import Qwen25_7BVLITokenizer, Qwen25_7BVLIModel +import logging + +logger = logging.getLogger(__name__) + +QUOTE_PAIRS = [("'", "'"), ('"', '"'), ("\u2018", "\u2019"), ("\u201c", "\u201d")] +QUOTE_PATTERN = "|".join( + [ + re.escape(q1) + r"[^" + re.escape(q1 + q2) + r"]*?" 
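# --- Editor's sketch (illustrative, not part of the diff): the pooling
# arithmetic behind Gemma3MultiModalProjector for the vision config above
# (896px image, 14px patches, 256 image tokens).
image_size, patch_size, mm_tokens_per_image = 896, 14, 256
patches_per_side = image_size // patch_size        # 64
tokens_per_side = int(mm_tokens_per_image ** 0.5)  # 16
kernel_size = patches_per_side // tokens_per_side  # 4 -> AvgPool2d(4, stride=4)
assert tokens_per_side * tokens_per_side == mm_tokens_per_image
# --- end sketch ---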
+ re.escape(q2) + for q1, q2 in QUOTE_PAIRS + ] +) +WORD_INTERNAL_QUOTE_RE = re.compile(r"[a-zA-Z]+'[a-zA-Z]+") + + +def split_quotation(prompt): + matches = WORD_INTERNAL_QUOTE_RE.findall(prompt) + mapping = [] + for i, word_src in enumerate(set(matches)): + word_tgt = "longcat_$##$_longcat" * (i + 1) + prompt = prompt.replace(word_src, word_tgt) + mapping.append((word_src, word_tgt)) + + parts = re.split(f"({QUOTE_PATTERN})", prompt) + result = [] + for part in parts: + for word_src, word_tgt in mapping: + part = part.replace(word_tgt, word_src) + if not part: + continue + is_quoted = bool(re.match(QUOTE_PATTERN, part)) + result.append((part, is_quoted)) + return result + + +class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.max_length = 512 + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + parts = split_quotation(text) + all_tokens = [] + for part_text, is_quoted in parts: + if is_quoted: + for char in part_text: + ids = self.tokenizer(char, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + else: + ids = self.tokenizer(part_text, add_special_tokens=False)["input_ids"] + all_tokens.extend(ids) + + if len(all_tokens) > self.max_length: + all_tokens = all_tokens[: self.max_length] + logger.warning(f"Truncated prompt to {self.max_length} tokens") + + output = [(t, 1.0) for t in all_tokens] + # Pad to max length + self.pad_tokens(output, self.max_length - len(output)) + return [output] + + +IMAGE_PAD_TOKEN_ID = 151655 + +class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): + T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n" + EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). 
Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>" + SUFFIX = "<|im_end|>\n<|im_start|>assistant\n" + + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="qwen25_7b", + tokenizer=LongCatImageBaseTokenizer, + ) + + def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs): + skip_template = False + if text.startswith("<|im_start|>"): + skip_template = True + if text.startswith("<|start_header_id|>"): + skip_template = True + if text == "": + text = " " + + base_tok = getattr(self, "qwen25_7b") + if skip_template: + tokens = super().tokenize_with_weights( + text, return_word_ids=return_word_ids, disable_weights=True, **kwargs + ) + else: + has_images = images is not None and len(images) > 0 + template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX + + prefix_ids = base_tok.tokenizer( + template_prefix, add_special_tokens=False + )["input_ids"] + suffix_ids = base_tok.tokenizer( + self.SUFFIX, add_special_tokens=False + )["input_ids"] + + prompt_tokens = base_tok.tokenize_with_weights( + text, return_word_ids=return_word_ids, **kwargs + ) + prompt_pairs = prompt_tokens[0] + + prefix_pairs = [(t, 1.0) for t in prefix_ids] + suffix_pairs = [(t, 1.0) for t in suffix_ids] + + combined = prefix_pairs + prompt_pairs + suffix_pairs + + if has_images: + embed_count = 0 + for i in range(len(combined)): + if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images): + combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1]) + embed_count += 1 + + tokens = {"qwen25_7b": [combined]} + + return tokens + + +class LongCatImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__( + device=device, + dtype=dtype, + name="qwen25_7b", + clip_model=Qwen25_7BVLIModel, + model_options=model_options, + ) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen25_7b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: + if tok_pairs[template_end + 2][0] == 198: + template_end += 3 + + if template_end == -1: + template_end = 0 + + suffix_start = None + for i in range(len(tok_pairs) - 1, -1, -1): + elem = tok_pairs[i][0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151645: + suffix_start = i + break + + out = out[:, template_end:] + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + if suffix_start is not None: + suffix_len = len(tok_pairs) - suffix_start + if suffix_len > 0 and out.shape[1] > suffix_len: + out = out[:, :-suffix_len] + if "attention_mask" in extra: + extra["attention_mask"] = 
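# --- Editor's sketch (illustrative, not part of the diff): how the quote
# splitter above drives tokenization -- quoted spans are tokenized character
# by character (helping exact text rendering) while the rest is tokenized
# normally; word-internal apostrophes (e.g. "it's") are shielded by a
# placeholder so they are not mistaken for quotes. Expected output shown as a
# comment.
parts = split_quotation('a sign that says "OPEN" at night')
# -> [('a sign that says ', False), ('"OPEN"', True), (' at night', False)]
# --- end sketch ---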
extra["attention_mask"][:, :-suffix_len] + if extra["attention_mask"].sum() == torch.numel( + extra["attention_mask"] + ): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class LongCatImageTEModel_(LongCatImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + + return LongCatImageTEModel_ diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,7 +1,12 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +import torch +import comfy.utils +import math +import itertools class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +21,247 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_Tokenizer(): + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs): + self.llama_template = "system\nYou are a helpful assistant.\nuser\n{}\nmodel\n" + self.llama_template_images = "system\nYou are a helpful assistant.\nuser\n\n{}\n\nmodel\n" + + if image is None: + images = [] + else: + samples = image.movedim(-1, 1) + total = int(896 * 896) + + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1) + images = [s[:, :, :, :3]] + + if text.startswith(''): + skip_template = True + + if skip_template: + llama_text = text + else: + if llama_template is None: + if len(images) > 0: + llama_text = self.llama_template_images.format(text) + else: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + text_tokens = super().tokenize_with_weights(llama_text, return_word_ids) + + if len(images) > 0: + embed_count = 0 + for r in text_tokens: + for i, token in enumerate(r): + if token[0] == 262144 and embed_count < len(images): + r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:] + embed_count += 1 + return text_tokens + +class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + special_tokens = {"": 262144, "": 106} + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1024, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data) + + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + 
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
+
+
+class Gemma3_12BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+        self.dtypes = set()
+        self.dtypes.add(dtype)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+    def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+        tokens_only = [[t[0] for t in b] for b in tokens]
+        embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
+        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
+
+class DualLinearProjection(torch.nn.Module):
+    def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
+        self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        source_dim = x.shape[-1]
+        x = x.movedim(1, -1)
+        x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
+
+        video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
+        audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
+        return torch.cat((video, audio), dim=-1)
+
+class LTXAVTEModel(torch.nn.Module):
+    def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
+        super().__init__()
+        self.dtypes = set()
+        self.dtypes.add(dtype)
+        self.compat_mode = False
+        self.text_projection_type = text_projection_type
+
+        self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
+        self.dtypes.add(dtype_llama)
+
+        operations = self.gemma3_12b.operations # TODO
+
+        if self.text_projection_type == "single_linear":
+            self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
+        elif self.text_projection_type == "dual_linear":
+            self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
+
+
+    def enable_compat_mode(self): # TODO: remove
+        from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
+        operations = self.gemma3_12b.operations
+        dtype = self.text_embedding_projection.weight.dtype
+        device = self.text_embedding_projection.weight.device
+        self.audio_embeddings_connector = Embeddings1DConnector(
+            split_rope=True,
+            double_precision_rope=True,
+            dtype=dtype,
+            device=device,
operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + self.compat_mode = True + + def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + self.execution_device = None + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out = out[:, :, -torch.sum(extra["attention_mask"]).item():] + out_device = out.device + if comfy.model_management.should_use_bf16(self.execution_device): + out = out.to(device=self.execution_device, dtype=torch.bfloat16) + + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra + + def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): + return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + missing_all = [] + unexpected_all = [] + + for prefix, component in [("text_embedding_projection.", self.text_embedding_projection)]: + component_sd = {k.replace(prefix, ""): v for k, v in sdo.items() if k.startswith(prefix)} + if component_sd: + missing, unexpected = component.load_state_dict(component_sd, strict=False, assign=getattr(self, "can_assign_sd", False)) + missing_all.extend([f"{prefix}{k}" for k in missing]) + unexpected_all.extend([f"{prefix}{k}" for k in unexpected]) + + if "model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.2.attn1.to_q.bias" not in sd: # TODO: remove + ww = sd.get("model.diffusion_model.audio_embeddings_connector.transformer_1d_blocks.0.attn1.to_q.bias", None) + if ww is not None: + if ww.shape[0] == 3840: + self.enable_compat_mode() + sdv = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.video_embeddings_connector.": ""}, filter_keys=True) + self.video_embeddings_connector.load_state_dict(sdv, strict=False, assign=getattr(self, "can_assign_sd", False)) + sda = comfy.utils.state_dict_prefix_replace(sd, {"model.diffusion_model.audio_embeddings_connector.": ""}, filter_keys=True) + 
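+                        # Compat path: older checkpoints stored the 1D video/audio
+                        # connectors under the diffusion model prefix, so they are
+                        # loaded into the text encoder here.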
+                        self.audio_embeddings_connector.load_state_dict(sda, strict=False, assign=getattr(self, "can_assign_sd", False))
+
+        return (missing_all, unexpected_all)
+
+    def memory_estimation_function(self, token_weight_pairs, device=None):
+        constant = 6.0
+        if comfy.model_management.should_use_bf16(device):
+            constant /= 2.0
+
+        token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
+        m = min([sum(1 for _ in itertools.takewhile(lambda x: x[0] == 0, sub)) for sub in token_weight_pairs])
+
+        num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) - m
+        num_tokens = max(num_tokens, 642)
+        return num_tokens * constant * 1024 * 1024
+
+def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
+    class LTXAVTEModel_(LTXAVTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
+    return LTXAVTEModel_
+
+
+def sd_detect(state_dict_list, prefix=""):
+    for sd in state_dict_list:
+        if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
+            return {"text_projection_type": "dual_linear"}
+        if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
+            return {"text_projection_type": "single_linear"}
+    return {}
+
+
+def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
+    class Gemma3_12BModel_(Gemma3_12BModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_quantization_metadata is not None:
+                model_options = model_options.copy()
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Gemma3_12BModel_
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index 7a6cfdab2..01ebdfabe 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -1,23 +1,23 @@
 from comfy import sd1_clip
 from .spiece_tokenizer import SPieceTokenizer
 import comfy.text_encoders.llama
-
+from comfy.text_encoders.lt import Gemma3_Tokenizer
+import comfy.utils

 class Gemma2BTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+        special_tokens = {"<end_of_turn>": 107}
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)

     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}

-class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
+class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
-
-    def state_dict(self):
-        return {"spiece_model": self.tokenizer.serialize_model()}
+        special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)

 class LuminaTokenizer(sd1_clip.SD1Tokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -33,8 +33,27 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):

 class Gemma3_4BModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)

+class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
+            model_options = model_options.copy()
+            model_options["quantization_metadata"] = llama_quantization_metadata
+
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+    def process_tokens(self, tokens, device):
+        embeds, _, _, embeds_info = super().process_tokens(tokens, device)
+        comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+        return embeds
+
 class LuminaModel(sd1_clip.SD1ClipModel):
     def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
         super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -45,6 +64,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
         model = Gemma2_2BModel
     elif model_type == "gemma3_4b":
         model = Gemma3_4BModel
+    elif model_type == "gemma3_4b_vision":
+        model = Gemma3_4B_Vision_Model

     class LuminaTEModel_(LuminaModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
diff --git a/comfy/text_encoders/newbie.py b/comfy/text_encoders/newbie.py
new file mode 100644
index 000000000..db2324576
--- /dev/null
+++
b/comfy/text_encoders/newbie.py @@ -0,0 +1,62 @@ +import torch + +import comfy.model_management +import comfy.text_encoders.jina_clip_2 +import comfy.text_encoders.lumina2 + +class NewBieTokenizer: + def __init__(self, embedding_directory=None, tokenizer_data={}): + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BTokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["gemma_spiece_model"]}) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2Tokenizer(embedding_directory=embedding_directory, tokenizer_data={"spiece_model": tokenizer_data["jina_spiece_model"]}) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = {} + out["gemma"] = self.gemma.tokenize_with_weights(text, return_word_ids, **kwargs) + out["jina"] = self.jina.tokenize_with_weights(text, return_word_ids, **kwargs) + return out + + def untokenize(self, token_weight_pair): + raise NotImplementedError + + def state_dict(self): + return {} + +class NewBieTEModel(torch.nn.Module): + def __init__(self, dtype_gemma=None, device="cpu", dtype=None, model_options={}): + super().__init__() + dtype_gemma = comfy.model_management.pick_weight_dtype(dtype_gemma, dtype, device) + self.gemma = comfy.text_encoders.lumina2.Gemma3_4BModel(device=device, dtype=dtype_gemma, model_options=model_options) + self.jina = comfy.text_encoders.jina_clip_2.JinaClip2TextModel(device=device, dtype=dtype, model_options=model_options) + self.dtypes = {dtype, dtype_gemma} + + def set_clip_options(self, options): + self.gemma.set_clip_options(options) + self.jina.set_clip_options(options) + + def reset_clip_options(self): + self.gemma.reset_clip_options() + self.jina.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs_gemma = token_weight_pairs["gemma"] + token_weight_pairs_jina = token_weight_pairs["jina"] + + gemma_out, gemma_pooled, gemma_extra = self.gemma.encode_token_weights(token_weight_pairs_gemma) + jina_out, jina_pooled, jina_extra = self.jina.encode_token_weights(token_weight_pairs_jina) + + return gemma_out, jina_pooled, gemma_extra + + def load_sd(self, sd): + if "model.layers.0.self_attn.q_norm.weight" in sd: + return self.gemma.load_sd(sd) + else: + return self.jina.load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class NewBieTEModel_(NewBieTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + super().__init__(dtype_gemma=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return NewBieTEModel_ diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 5754424d2..2cc0867c3 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -61,6 +61,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return OvisTEModel_ diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index e5e5f18be..51c6e50c7 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, 
t5_quantization_metadata=None): if t5_quantization_metadata is not None: model_options = model_options.copy() model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata - if dtype is None: + if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) return PixArtTEModel_ diff --git a/comfy/text_encoders/qwen_vl.py b/comfy/text_encoders/qwen_vl.py index 3b18ce730..98c350a12 100644 --- a/comfy/text_encoders/qwen_vl.py +++ b/comfy/text_encoders/qwen_vl.py @@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module): hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) hidden_states = self.merger(hidden_states) + # Potentially important for spatially precise edits. This is present in the HF implementation. + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] return hidden_states diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py index caccb3ca2..099d8d2d9 100644 --- a/comfy/text_encoders/spiece_tokenizer.py +++ b/comfy/text_encoders/spiece_tokenizer.py @@ -6,9 +6,10 @@ class SPieceTokenizer: def from_pretrained(path, **kwargs): return SPieceTokenizer(path, **kwargs) - def __init__(self, tokenizer_path, add_bos=False, add_eos=True): + def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None): self.add_bos = add_bos self.add_eos = add_eos + self.special_tokens = special_tokens import sentencepiece if torch.is_tensor(tokenizer_path): tokenizer_path = tokenizer_path.numpy().tobytes() @@ -27,8 +28,32 @@ class SPieceTokenizer: return out def __call__(self, string): + if self.special_tokens is not None: + import re + special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys()) + if special_tokens_pattern and re.search(special_tokens_pattern, string): + parts = re.split(f'({special_tokens_pattern})', string) + result = [] + for part in parts: + if not part: + continue + if part in self.special_tokens: + result.append(self.special_tokens[part]) + else: + encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False) + result.extend(encoded) + return {"input_ids": result} + out = self.tokenizer.encode(string) return {"input_ids": out} + def decode(self, token_ids, skip_special_tokens=False): + + if skip_special_tokens and self.special_tokens: + special_token_ids = set(self.special_tokens.values()) + token_ids = [tid for tid in token_ids if tid not in special_token_ids] + + return self.tokenizer.decode(token_ids) + def serialize_model(self): return torch.ByteTensor(list(self.tokenizer.serialized_model_proto())) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index 19adde0b7..33b7cf594 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -6,7 +6,7 @@ import os class Qwen3Tokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=2560, 
embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class ZImageTokenizer(sd1_clip.SD1Tokenizer): @@ -40,6 +40,7 @@ def te(dtype_llama=None, llama_quantization_metadata=None): if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, dtype=dtype, model_options=model_options) return ZImageTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 8d4e2b445..78c491b98 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,40 +20,104 @@ import torch import math import struct -import comfy.checkpoint_pickle +import ctypes +import os +import comfy.memory_management import safetensors.torch import numpy as np from PIL import Image import logging import itertools from torch.nn.functional import interpolate +from tqdm.auto import trange from einops import rearrange from comfy.cli_args import args import json +import time +import threading +import warnings MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap -ALWAYS_SAFE_LOAD = False -if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated + +if True: # ckpt/pt file whitelist for safe loading of old sd files class ModelCheckpoint: pass ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint" def scalar(*args, **kwargs): - from numpy.core.multiarray import scalar as sc - return sc(*args, **kwargs) + return None scalar.__module__ = "numpy.core.multiarray" from numpy import dtype from numpy.dtypes import Float64DType - from _codecs import encode + + def encode(*args, **kwargs): # no longer necessary on newer torch + return None + encode.__module__ = "_codecs" torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode]) - ALWAYS_SAFE_LOAD = True logging.info("Checkpoint files will always be loaded safely.") -else: - logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. 
Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.") + + +# Current as of safetensors 0.7.0 +_TYPES = { + "F64": torch.float64, + "F32": torch.float32, + "F16": torch.float16, + "BF16": torch.bfloat16, + "I64": torch.int64, + "I32": torch.int32, + "I16": torch.int16, + "I8": torch.int8, + "U8": torch.uint8, + "BOOL": torch.bool, + "F8_E4M3": torch.float8_e4m3fn, + "F8_E5M2": torch.float8_e5m2, + "C64": torch.complex64, + + "U64": torch.uint64, + "U32": torch.uint32, + "U16": torch.uint16, +} + +def load_safetensors(ckpt): + import comfy_aimdo.model_mmap + + f = open(ckpt, "rb", buffering=0) + model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) + file_size = os.path.getsize(ckpt) + mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) + + header_size = struct.unpack(" 0: message = e.args[0] @@ -83,11 +152,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): if MMAP_TORCH_FILES: torch_args["mmap"] = True - if safe_load or ALWAYS_SAFE_LOAD: - pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) - else: - logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt)) - pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle) + pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args) + if "state_dict" in pl_sd: sd = pl_sd["state_dict"] else: @@ -610,10 +676,18 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.net.0.proj.bias": "txt_mlp.0.bias", "ff_context.net.2.weight": "txt_mlp.2.weight", "ff_context.net.2.bias": "txt_mlp.2.bias", - "attn.norm_q.weight": "img_attn.norm.query_norm.scale", - "attn.norm_k.weight": "img_attn.norm.key_norm.scale", - "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", - "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + "ff.linear_in.weight": "img_mlp.0.weight", # LyCoris LoKr + "ff.linear_in.bias": "img_mlp.0.bias", + "ff.linear_out.weight": "img_mlp.2.weight", + "ff.linear_out.bias": "img_mlp.2.bias", + "ff_context.linear_in.weight": "txt_mlp.0.weight", + "ff_context.linear_in.bias": "txt_mlp.0.bias", + "ff_context.linear_out.weight": "txt_mlp.2.weight", + "ff_context.linear_out.bias": "txt_mlp.2.bias", + "attn.norm_q.weight": "img_attn.norm.query_norm.weight", + "attn.norm_k.weight": "img_attn.norm.key_norm.weight", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight", } for k in block_map: @@ -636,8 +710,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", - "attn.norm_q.weight": "norm.query_norm.scale", - "attn.norm_k.weight": "norm.key_norm.scale", + "attn.norm_q.weight": "norm.query_norm.weight", + "attn.norm_k.weight": "norm.key_norm.weight", + "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 + "attn.to_out.weight": "linear2.weight", # Flux 2 } for k in block_map: @@ -805,20 +881,35 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value 
is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): + # Clone inference tensors (created under torch.inference_mode) since + # their version counter is frozen and nn.Parameter() cannot wrap them. + if (not torch.is_inference_mode_enabled()) and value.is_inference(): + value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") @@ -928,7 +1019,9 @@ def bislerp(samples, width, height): return result.to(orig_dtype) def lanczos(samples, width, height): - images = [Image.fromarray(np.clip(255. * image.movedim(0, -1).cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] + #the below API is strict and expects grayscale to be squeezed + samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) + images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] result = torch.stack(images) @@ -1042,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out = output[b:b+1].zero_() + out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] @@ -1058,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) @@ -1081,12 +1174,38 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div + out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar) +def model_trange(*args, **kwargs): + if not comfy.memory_management.aimdo_enabled: + return trange(*args, **kwargs) + + pbar = trange(*args, **kwargs, smoothing=1.0) + pbar._i = 0 + pbar.set_postfix_str(" Model Initializing ... ") + + _update = pbar.update + + def warmup_update(n=1): + pbar._i += 1 + if pbar._i == 1: + pbar.i1_time = time.time() + pbar.set_postfix_str(" Model Initialization complete! 
") + elif pbar._i == 2: + #bring forward the effective start time based the the diff between first and second iteration + #to attempt to remove load overhead from the final step rate estimate. + pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time) + pbar.set_postfix_str("") + + _update(n) + + pbar.update = warmup_update + return pbar + PROGRESS_BAR_ENABLED = True def set_progress_bar_enabled(enabled): global PROGRESS_BAR_ENABLED @@ -1097,6 +1216,10 @@ def set_progress_bar_global_hook(function): global PROGRESS_BAR_HOOK PROGRESS_BAR_HOOK = function +# Throttle settings for progress bar updates to reduce WebSocket flooding +PROGRESS_THROTTLE_MIN_INTERVAL = 0.1 # 100ms minimum between updates +PROGRESS_THROTTLE_MIN_PERCENT = 0.5 # 0.5% minimum progress change + class ProgressBar: def __init__(self, total, node_id=None): global PROGRESS_BAR_HOOK @@ -1104,6 +1227,8 @@ class ProgressBar: self.current = 0 self.hook = PROGRESS_BAR_HOOK self.node_id = node_id + self._last_update_time = 0.0 + self._last_sent_value = -1 def update_absolute(self, value, total=None, preview=None): if total is not None: @@ -1112,7 +1237,29 @@ class ProgressBar: value = self.total self.current = value if self.hook is not None: - self.hook(self.current, self.total, preview, node_id=self.node_id) + current_time = time.perf_counter() + is_first = (self._last_sent_value < 0) + is_final = (value >= self.total) + has_preview = (preview is not None) + + # Always send immediately for previews, first update, or final update + if has_preview or is_first or is_final: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value + return + + # Apply throttling for regular progress updates + if self.total > 0: + percent_changed = ((value - max(0, self._last_sent_value)) / self.total) * 100 + else: + percent_changed = 100 + time_elapsed = current_time - self._last_update_time + + if time_elapsed >= PROGRESS_THROTTLE_MIN_INTERVAL and percent_changed >= PROGRESS_THROTTLE_MIN_PERCENT: + self.hook(self.current, self.total, preview, node_id=self.node_id) + self._last_update_time = current_time + self._last_sent_value = value def update(self, value): self.update_absolute(self.current + value) @@ -1198,7 +1345,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): @@ -1230,6 +1377,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): out_sd = {} layers = {} for k in list(state_dict.keys()): + if k == scaled_fp8_key: + continue if not k.startswith(model_prefix): out_sd[k] = state_dict[k] continue @@ -1265,3 +1414,42 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}): state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8) return state_dict, metadata + +def string_to_seed(data): + crc = 0xFFFFFFFF + for byte in data: + if isinstance(byte, str): + byte = ord(byte) + crc ^= byte + for _ in range(8): + if crc & 1: + crc = (crc >> 1) ^ 0xEDB88320 + else: + crc >>= 1 + return crc ^ 0xFFFFFFFF + +def deepcopy_list_dict(obj, memo=None): + if memo is None: + memo = {} + + obj_id = id(obj) + if obj_id in memo: + return memo[obj_id] + + if isinstance(obj, dict): + res = {deepcopy_list_dict(k, memo): 
deepcopy_list_dict(v, memo) for k, v in obj.items()} + elif isinstance(obj, list): + res = [deepcopy_list_dict(i, memo) for i in obj] + else: + res = obj + + memo[obj_id] = res + return res + +def normalize_image_embeddings(embeds, embeds_info, scale_factor): + """Normalize image embeddings to match text embedding scale""" + for info in embeds_info: + if info.get("type") == "image": + start_idx = info["index"] + end_idx = start_idx + info["size"] + embeds[:, start_idx:end_idx, :] /= scale_factor diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py index b40f920e4..b9fa8d5cf 100644 --- a/comfy/weight_adapter/__init__.py +++ b/comfy/weight_adapter/__init__.py @@ -5,6 +5,11 @@ from .lokr import LoKrAdapter from .glora import GLoRAAdapter from .oft import OFTAdapter from .boft import BOFTAdapter +from .bypass import ( + BypassInjectionManager, + BypassForwardHook, + create_bypass_injections_from_patches, +) adapters: list[type[WeightAdapterBase]] = [ @@ -31,4 +36,7 @@ __all__ = [ "WeightAdapterTrainBase", "adapters", "adapter_maps", + "BypassInjectionManager", + "BypassForwardHook", + "create_bypass_injections_from_patches", ] + [a.__name__ for a in adapters] diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py index 43644b106..d352e066b 100644 --- a/comfy/weight_adapter/base.py +++ b/comfy/weight_adapter/base.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch import torch.nn as nn @@ -7,12 +7,35 @@ import comfy.model_management class WeightAdapterBase: + """ + Base class for weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT/BOFT: g = transform, h = 0 + """ + name: str loaded_keys: set[str] weights: list[torch.Tensor] + # Attributes set by bypass system + multiplier: float = 1.0 + shape: tuple = None # (out_features, in_features) or (out_ch, in_ch, *kernel) + @classmethod - def load(cls, x: str, lora: dict[str, torch.Tensor], alpha: float, dora_scale: torch.Tensor) -> Optional["WeightAdapterBase"]: + def load( + cls, + x: str, + lora: dict[str, torch.Tensor], + alpha: float, + dora_scale: torch.Tensor, + ) -> Optional["WeightAdapterBase"]: raise NotImplementedError def to_train(self) -> "WeightAdapterTrainBase": @@ -26,6 +49,12 @@ class WeightAdapterBase: """ raise NotImplementedError + def calculate_shape( + self, + key + ): + return None + def calculate_weight( self, weight, @@ -39,18 +68,202 @@ class WeightAdapterBase: ): raise NotImplementedError + # ===== Bypass Mode Methods ===== + # + # IMPORTANT: Bypass mode is designed for quantized models where original weights + # may not be accessible in a usable format. Therefore, h() and bypass_forward() + # do NOT take org_weight as a parameter. All necessary information (out_channels, + # in_channels, conv params, etc.) is provided via attributes set by BypassForwardHook. + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT/BOFT), returns zeros. + + Note: + This method does NOT access original model weights. 
Bypass mode is + designed for quantized models where weights may not be in a usable format. + All shape info comes from module attributes set by BypassForwardHook. + + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Reference: LyCORIS LoConModule.bypass_forward_diff + """ + # Default: no additive component (for OFT/BOFT) + # Simply return zeros matching base_out shape + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT/BOFT override this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + + Reference: LyCORIS OFTModule applies orthogonal transform here + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Note: + This method does NOT take org_weight/org_bias parameters. Bypass mode + is designed for quantized models where weights may not be accessible. + The original forward function handles weight access internally. + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + + Reference: LyCORIS LoConModule.bypass_forward + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + class WeightAdapterTrainBase(nn.Module): - # We follow the scheme of PR #7032 + """ + Base class for trainable weight adapters (LoRA, LoHa, LoKr, OFT, etc.) + + Bypass Mode: + All adapters follow the pattern: bypass(f)(x) = g(f(x) + h(x)) + + - h(x): Additive component (LoRA path). Returns delta to add to base output. + - g(y): Output transformation. Applied after base + h(x). + + For LoRA/LoHa/LoKr: g = identity, h = adapter(x) + For OFT: g = transform, h = 0 + + Note: + Unlike WeightAdapterBase, TrainBase classes have simplified weight formats + with fewer branches (e.g., LoKr only has w1/w2, not w1_a/w1_b decomposition). + + We follow the scheme of PR #7032 + """ + + # Attributes set by bypass system (BypassForwardHook) + # These are set before h()/g()/bypass_forward() are called + multiplier: float = 1.0 + is_conv: bool = False + conv_dim: int = 0 # 0=linear, 1=conv1d, 2=conv2d, 3=conv3d + kw_dict: dict = {} # Conv kwargs: stride, padding, dilation, groups + kernel_size: tuple = () + in_channels: int = None + out_channels: int = None + def __init__(self): super().__init__() def __call__(self, w): """ - w: The original weight tensor to be modified. + Weight modification mode: returns modified weight. + + Args: + w: The original weight tensor to be modified. + + Returns: + Modified weight tensor. """ raise NotImplementedError + # ===== Bypass Mode Methods ===== + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component: h(x, base_out) + + Computes the adapter's contribution to be added to base forward output. + For adapters that only transform output (OFT), returns zeros. 
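+
+        Illustrative sketch (assumed names, not part of this class): a plain
+        LoRA trainer holding lora_up (out x rank) and lora_down (rank x in)
+        parameters would implement this for a Linear layer roughly as:
+
+            def h(self, x, base_out):
+                down = torch.nn.functional.linear(x, self.lora_down)
+                return self.multiplier * torch.nn.functional.linear(down, self.lora_up)
+
+        leaving g() as the identity.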
+ + Args: + x: Input tensor + base_out: Output from base forward f(x), can be used for shape reference + + Returns: + Delta tensor to add to base output. Shape matches base output. + + Subclasses should override this method. + """ + raise NotImplementedError( + f"{self.__class__.__name__}.h() not implemented. " + "Subclasses must implement h() for bypass mode." + ) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation: g(y) + + Applied after base forward + h(x). For most adapters this is identity. + OFT overrides this to apply orthogonal transformation. + + Args: + y: Combined output (base + h(x)) + + Returns: + Transformed output + """ + # Default: identity (for LoRA/LoHa/LoKr) + return y + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + Full bypass forward: g(f(x) + h(x, f(x))) + + Args: + org_forward: Original module forward function + x: Input tensor + *args, **kwargs: Additional arguments for org_forward + + Returns: + Output with adapter applied in bypass mode + """ + # Base forward: f(x) + base_out = org_forward(x, *args, **kwargs) + + # Additive component: h(x, base_out) - base_out provided for shape reference + h_out = self.h(x, base_out) + + # Output transformation: g(base + h) + return self.g(base_out + h_out) + def passive_memory_usage(self): raise NotImplementedError("passive_memory_usage is not implemented") @@ -59,8 +272,12 @@ class WeightAdapterTrainBase(nn.Module): return self.passive_memory_usage() -def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function): - dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype) +def weight_decompose( + dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function +): + dora_scale = comfy.model_management.cast_to_device( + dora_scale, weight.device, intermediate_dtype + ) lora_diff *= alpha weight_calc = weight + function(lora_diff).type(weight.dtype) @@ -106,10 +323,14 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten the original tensor will be truncated in that dimension. 
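+
+    Example (hypothetical values): pad_tensor_to_shape(torch.ones(2, 3), [4, 5])
+    returns a (4, 5) tensor with ones in the [:2, :3] block and zeros elsewhere.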
""" if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]): - raise ValueError("The new shape must be larger than the original tensor in all dimensions") + raise ValueError( + "The new shape must be larger than the original tensor in all dimensions" + ) if len(new_shape) != len(tensor.shape): - raise ValueError("The new shape must have the same number of dimensions as the original tensor") + raise ValueError( + "The new shape must have the same number of dimensions as the original tensor" + ) # Create a new tensor filled with zeros padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py index b2a2f1bd4..02a8dc130 100644 --- a/comfy/weight_adapter/boft.py +++ b/comfy/weight_adapter/boft.py @@ -62,9 +62,13 @@ class BOFTAdapter(WeightAdapterBase): alpha = v[2] dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) boft_m, block_num, boft_b, *_ = blocks.shape @@ -74,7 +78,7 @@ class BOFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(-1, -2) normed_q = q - if alpha > 0: # alpha in boft/bboft is for constraint + if alpha > 0: # alpha in boft/bboft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm @@ -83,13 +87,13 @@ class BOFTAdapter(WeightAdapterBase): r = r.to(weight) inp = org = weight - r_b = boft_b//2 + r_b = boft_b // 2 for i in range(boft_m): bi = r[i] g = 2 k = 2**i * r_b if strength != 1: - bi = bi * strength + (1-strength) * I + bi = bi * strength + (1 - strength) * I inp = ( inp.unflatten(0, (-1, g, k)) .transpose(1, 2) @@ -98,18 +102,117 @@ class BOFTAdapter(WeightAdapterBase): ) inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp) inp = ( - inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2) + inp.flatten(0, 1) + .unflatten(0, (-1, k, g)) + .transpose(1, 2) + .flatten(0, 2) ) if rescale is not None: inp = inp * rescale lora_diff = inp - org - lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype) + lora_diff = comfy.model_management.cast_to_device( + lora_diff, weight.device, intermediate_dtype + ) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrices(self, device, dtype): + """Compute the orthogonal rotation matrices R from BOFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + boft_m, block_num, boft_b, _ = blocks.shape + I = torch.eye(boft_b, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(-1, -2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / 
q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, boft_m, boft_b + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for BOFT: applies butterfly orthogonal transform. + + BOFT uses multiple stages of butterfly-structured orthogonal transforms. + + Reference: LyCORIS ButterflyOFTModule._bypass_forward + """ + v = self.weights + rescale = v[1] + + r, boft_m, boft_b = self._get_orthogonal_matrices(y.device, y.dtype) + r_b = boft_b // 2 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(boft_b, device=y.device, dtype=y.dtype) + + # Use module info from bypass injection to determine conv vs linear + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # Apply butterfly transform stages + inp = y + for i in range(boft_m): + bi = r[i] # (block_num, boft_b, boft_b) + g = 2 + k = 2**i * r_b + + # Interpolate with identity based on multiplier + if multiplier != 1: + bi = bi * multiplier + (1 - multiplier) * I + + # Reshape for butterfly: unflatten last dim, transpose, flatten, unflatten + inp = ( + inp.unflatten(-1, (-1, g, k)) + .transpose(-2, -1) + .flatten(-3) + .unflatten(-1, (-1, boft_b)) + ) + # Apply block-diagonal orthogonal transform + inp = torch.einsum("b i j, ... b j -> ... b i", bi, inp) + # Reshape back + inp = ( + inp.flatten(-2).unflatten(-1, (-1, k, g)).transpose(-2, -1).flatten(-3) + ) + + # Apply rescale if present + if rescale is not None: + rescale = rescale.to(device=y.device, dtype=y.dtype) + inp = inp * rescale.transpose(0, -1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) + inp = inp.transpose(1, -1) + + return inp diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py new file mode 100644 index 000000000..b9d5ec7d9 --- /dev/null +++ b/comfy/weight_adapter/bypass.py @@ -0,0 +1,441 @@ +""" +Bypass mode implementation for weight adapters (LoRA, LoKr, LoHa, etc.) + +Bypass mode applies adapters during forward pass without modifying base weights: + bypass(f)(x) = g(f(x) + h(x)) + +Where: + - f(x): Original layer forward + - h(x): Additive component from adapter (LoRA path) + - g(y): Output transformation (identity for most adapters) + +This is useful for: + - Training with gradient checkpointing + - Avoiding weight modifications when weights are offloaded + - Supporting multiple adapters with different strengths dynamically +""" + +import logging +from typing import Optional, Union + +import torch +import torch.nn as nn + +import comfy.model_management +from .base import WeightAdapterBase, WeightAdapterTrainBase +from comfy.patcher_extension import PatcherInjection + +# Type alias for adapters that support bypass mode +BypassAdapter = Union[WeightAdapterBase, WeightAdapterTrainBase] + + +def get_module_type_info(module: nn.Module) -> dict: + """ + Determine module type and extract conv parameters from module class. + + This is more reliable than checking weight.ndim, especially for quantized layers + where weight shape might be different. 
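+
+    For example, get_module_type_info(nn.Conv2d(4, 8, 3, stride=2)) would
+    yield is_conv=True, conv_dim=2, kernel_size=(3, 3), stride=(2, 2),
+    in_channels=4 and out_channels=8.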
+ + Returns: + dict with keys: is_conv, conv_dim, stride, padding, dilation, groups + """ + info = { + "is_conv": False, + "conv_dim": 0, + "stride": (1,), + "padding": (0,), + "dilation": (1,), + "groups": 1, + "kernel_size": (1,), + "in_channels": None, + "out_channels": None, + } + + # Determine conv type + if isinstance(module, nn.Conv1d): + info["is_conv"] = True + info["conv_dim"] = 1 + elif isinstance(module, nn.Conv2d): + info["is_conv"] = True + info["conv_dim"] = 2 + elif isinstance(module, nn.Conv3d): + info["is_conv"] = True + info["conv_dim"] = 3 + elif isinstance(module, nn.Linear): + info["is_conv"] = False + info["conv_dim"] = 0 + else: + # Try to infer from class name for custom/quantized layers + class_name = type(module).__name__.lower() + if "conv3d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 3 + elif "conv2d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + elif "conv1d" in class_name: + info["is_conv"] = True + info["conv_dim"] = 1 + elif "conv" in class_name: + info["is_conv"] = True + info["conv_dim"] = 2 + + # Extract conv parameters if it's a conv layer + if info["is_conv"]: + # Try to get stride, padding, dilation, groups, kernel_size from module + info["stride"] = getattr(module, "stride", (1,) * info["conv_dim"]) + info["padding"] = getattr(module, "padding", (0,) * info["conv_dim"]) + info["dilation"] = getattr(module, "dilation", (1,) * info["conv_dim"]) + info["groups"] = getattr(module, "groups", 1) + info["kernel_size"] = getattr(module, "kernel_size", (1,) * info["conv_dim"]) + info["in_channels"] = getattr(module, "in_channels", None) + info["out_channels"] = getattr(module, "out_channels", None) + + # Ensure they're tuples + if isinstance(info["stride"], int): + info["stride"] = (info["stride"],) * info["conv_dim"] + if isinstance(info["padding"], int): + info["padding"] = (info["padding"],) * info["conv_dim"] + if isinstance(info["dilation"], int): + info["dilation"] = (info["dilation"],) * info["conv_dim"] + if isinstance(info["kernel_size"], int): + info["kernel_size"] = (info["kernel_size"],) * info["conv_dim"] + + return info + + +class BypassForwardHook: + """ + Hook that wraps a layer's forward to apply adapter in bypass mode. + + Stores the original forward and replaces it with bypass version. 
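+
+    Sketch of the intended call pattern (illustrative only):
+        hook = BypassForwardHook(module, adapter, multiplier=0.8)
+        hook.inject()   # module.forward now computes g(f(x) + h(x, f(x)))
+        out = module(x)
+        hook.eject()    # original forward restored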
+ + Supports both: + - WeightAdapterBase: Inference adapters (uses self.weights tuple) + - WeightAdapterTrainBase: Training adapters (nn.Module with parameters) + """ + + def __init__( + self, + module: nn.Module, + adapter: BypassAdapter, + multiplier: float = 1.0, + ): + self.module = module + self.adapter = adapter + self.multiplier = multiplier + self.original_forward = None + + # Determine layer type and conv params from module class (works for quantized layers) + module_info = get_module_type_info(module) + + # Set multiplier and layer type info on adapter for use in h() + adapter.multiplier = multiplier + adapter.is_conv = module_info["is_conv"] + adapter.conv_dim = module_info["conv_dim"] + adapter.kernel_size = module_info["kernel_size"] + adapter.in_channels = module_info["in_channels"] + adapter.out_channels = module_info["out_channels"] + # Store kw_dict for conv operations (like LyCORIS extra_args) + if module_info["is_conv"]: + adapter.kw_dict = { + "stride": module_info["stride"], + "padding": module_info["padding"], + "dilation": module_info["dilation"], + "groups": module_info["groups"], + } + else: + adapter.kw_dict = {} + + def _bypass_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + """Bypass forward: uses adapter's bypass_forward or default g(f(x) + h(x)) + + Note: + Bypass mode does NOT access original model weights (org_weight). + This is intentional - bypass mode is designed for quantized models + where weights may not be in a usable format. All necessary shape + information is provided via adapter attributes set during inject(). + """ + # Check if adapter has custom bypass_forward (e.g., GLoRA) + adapter_bypass = getattr(self.adapter, "bypass_forward", None) + if adapter_bypass is not None: + # Check if it's overridden (not the base class default) + # Need to check both base classes since adapter could be either type + adapter_type = type(self.adapter) + is_default_bypass = ( + adapter_type.bypass_forward is WeightAdapterBase.bypass_forward + or adapter_type.bypass_forward is WeightAdapterTrainBase.bypass_forward + ) + if not is_default_bypass: + return adapter_bypass(self.original_forward, x, *args, **kwargs) + + # Default bypass: g(f(x) + h(x, f(x))) + base_out = self.original_forward(x, *args, **kwargs) + h_out = self.adapter.h(x, base_out) + return self.adapter.g(base_out + h_out) + + def inject(self): + """Replace module forward with bypass version.""" + if self.original_forward is not None: + logging.debug( + f"[BypassHook] Already injected for {type(self.module).__name__}" + ) + return # Already injected + + # Move adapter weights to compute device (GPU) + # Use get_torch_device() instead of module.weight.device because + # with offloading, module weights may be on CPU while compute happens on GPU + device = comfy.model_management.get_torch_device() + + # Get dtype from module weight if available + dtype = None + if hasattr(self.module, "weight") and self.module.weight is not None: + dtype = self.module.weight.dtype + + # Only use dtype if it's a standard float type, not quantized + if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = None + + self._move_adapter_weights_to_device(device, dtype) + + self.original_forward = self.module.forward + self.module.forward = self._bypass_forward + logging.debug( + f"[BypassHook] Injected bypass forward for {type(self.module).__name__} (adapter={type(self.adapter).__name__})" + ) + + def _move_adapter_weights_to_device(self, device, dtype=None): + """Move 
adapter weights to specified device to avoid per-forward transfers. + + Handles both: + - WeightAdapterBase: has self.weights tuple of tensors + - WeightAdapterTrainBase: nn.Module with parameters, uses .to() method + """ + adapter = self.adapter + + # Check if adapter is an nn.Module (WeightAdapterTrainBase) + if isinstance(adapter, nn.Module): + # In training mode we don't touch dtype as trainer will handle it + adapter.to(device=device) + logging.debug( + f"[BypassHook] Moved training adapter (nn.Module) to {device}" + ) + return + + # WeightAdapterBase: handle self.weights tuple + if not hasattr(adapter, "weights") or adapter.weights is None: + return + + weights = adapter.weights + if isinstance(weights, (list, tuple)): + new_weights = [] + for w in weights: + if isinstance(w, torch.Tensor): + if dtype is not None: + new_weights.append(w.to(device=device, dtype=dtype)) + else: + new_weights.append(w.to(device=device)) + else: + new_weights.append(w) + adapter.weights = ( + tuple(new_weights) if isinstance(weights, tuple) else new_weights + ) + elif isinstance(weights, torch.Tensor): + if dtype is not None: + adapter.weights = weights.to(device=device, dtype=dtype) + else: + adapter.weights = weights.to(device=device) + + logging.debug(f"[BypassHook] Moved adapter weights to {device}") + + def eject(self): + """Restore original module forward.""" + if self.original_forward is None: + logging.debug(f"[BypassHook] Not injected for {type(self.module).__name__}") + return # Not injected + + self.module.forward = self.original_forward + self.original_forward = None + logging.debug( + f"[BypassHook] Ejected bypass forward for {type(self.module).__name__}" + ) + + +class BypassInjectionManager: + """ + Manages bypass mode injection for a collection of adapters. + + Creates PatcherInjection objects that can be used with ModelPatcher. + + Supports both inference adapters (WeightAdapterBase) and training adapters + (WeightAdapterTrainBase). + + Usage: + manager = BypassInjectionManager() + manager.add_adapter("model.layers.0.self_attn.q_proj", lora_adapter, strength=0.8) + manager.add_adapter("model.layers.0.self_attn.k_proj", lora_adapter, strength=0.8) + + injections = manager.create_injections(model) + model_patcher.set_injections("bypass_lora", injections) + """ + + def __init__(self): + self.adapters: dict[str, tuple[BypassAdapter, float]] = {} + self.hooks: list[BypassForwardHook] = [] + + def add_adapter( + self, + key: str, + adapter: BypassAdapter, + strength: float = 1.0, + ): + """ + Add an adapter for a specific weight key. + + Args: + key: Weight key (e.g., "model.layers.0.self_attn.q_proj.weight") + adapter: The weight adapter (LoRAAdapter, LoKrAdapter, etc.) 
+ strength: Multiplier for adapter effect + """ + # Remove .weight suffix if present for module lookup + module_key = key + if module_key.endswith(".weight"): + module_key = module_key[:-7] + logging.debug( + f"[BypassManager] Stripped .weight suffix: {key} -> {module_key}" + ) + + self.adapters[module_key] = (adapter, strength) + logging.debug( + f"[BypassManager] Added adapter: {module_key} (type={type(adapter).__name__}, strength={strength})" + ) + + def clear_adapters(self): + """Remove all adapters.""" + self.adapters.clear() + + def _get_module_by_key(self, model: nn.Module, key: str) -> Optional[nn.Module]: + """Get a submodule by dot-separated key.""" + parts = key.split(".") + module = model + try: + for i, part in enumerate(parts): + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + logging.debug( + f"[BypassManager] Found module for key {key}: {type(module).__name__}" + ) + return module + except (AttributeError, IndexError, KeyError) as e: + logging.error(f"[BypassManager] Failed to find module for key {key}: {e}") + logging.error( + f"[BypassManager] Failed at part index {i}, part={part}, current module type={type(module).__name__}" + ) + return None + + def create_injections(self, model: nn.Module) -> list[PatcherInjection]: + """ + Create PatcherInjection objects for all registered adapters. + + Args: + model: The model to inject into (e.g., model_patcher.model) + + Returns: + List of PatcherInjection objects to use with model_patcher.set_injections() + """ + self.hooks.clear() + + logging.debug( + f"[BypassManager] create_injections called with {len(self.adapters)} adapters" + ) + logging.debug(f"[BypassManager] Model type: {type(model).__name__}") + + for key, (adapter, strength) in self.adapters.items(): + logging.debug(f"[BypassManager] Looking for module: {key}") + module = self._get_module_by_key(model, key) + + if module is None: + logging.warning(f"[BypassManager] Module not found for key {key}") + continue + + if not hasattr(module, "weight"): + logging.warning( + f"[BypassManager] Module {key} has no weight attribute (type={type(module).__name__})" + ) + continue + + logging.debug( + f"[BypassManager] Creating hook for {key} (module type={type(module).__name__}, weight shape={module.weight.shape})" + ) + hook = BypassForwardHook(module, adapter, multiplier=strength) + self.hooks.append(hook) + + logging.debug(f"[BypassManager] Created {len(self.hooks)} hooks") + + # Create single injection that manages all hooks + def inject_all(model_patcher): + logging.debug( + f"[BypassManager] inject_all called, injecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.inject() + logging.debug( + f"[BypassManager] Injected hook for {type(hook.module).__name__}" + ) + + def eject_all(model_patcher): + logging.debug( + f"[BypassManager] eject_all called, ejecting {len(self.hooks)} hooks" + ) + for hook in self.hooks: + hook.eject() + + return [PatcherInjection(inject=inject_all, eject=eject_all)] + + def get_hook_count(self) -> int: + """Return number of hooks that will be/are injected.""" + return len(self.hooks) + + +def create_bypass_injections_from_patches( + model: nn.Module, + patches: dict, + strength: float = 1.0, +) -> list[PatcherInjection]: + """ + Convenience function to create bypass injections from a patches dict. + + This is useful when you have patches in the format used by model_patcher.add_patches() + and want to apply them in bypass mode instead. 
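+
+    Example (a minimal sketch; assumes ``model_patcher`` and a ``patches``
+    dict of adapter entries in the ``add_patches()`` format described
+    below)::
+
+        injections = create_bypass_injections_from_patches(
+            model_patcher.model, patches, strength=1.0
+        )
+        model_patcher.set_injections("bypass_lora", injections)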
+ + Args: + model: The model to inject into + patches: Dict mapping weight keys to adapter data + strength: Global strength multiplier + + Returns: + List of PatcherInjection objects + """ + manager = BypassInjectionManager() + + for key, patch_list in patches.items(): + if not patch_list: + continue + + # patches format: list of (strength_patch, patch_data, strength_model, offset, function) + for patch in patch_list: + patch_strength, patch_data, strength_model, offset, function = patch + + # patch_data should be a WeightAdapterBase/WeightAdapterTrainBase or tuple + if isinstance(patch_data, (WeightAdapterBase, WeightAdapterTrainBase)): + adapter = patch_data + else: + # Skip non-adapter patches + continue + + combined_strength = strength * patch_strength + manager.add_adapter(key, adapter, strength=combined_strength) + + return manager.create_injections(model) diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py index 939abbba5..d6b97a23b 100644 --- a/comfy/weight_adapter/glora.py +++ b/comfy/weight_adapter/glora.py @@ -1,7 +1,8 @@ import logging -from typing import Optional +from typing import Callable, Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, weight_decompose @@ -29,7 +30,14 @@ class GLoRAAdapter(WeightAdapterBase): b1_name = "{}.b1.weight".format(x) b2_name = "{}.b2.weight".format(x) if a1_name in lora: - weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale) + weights = ( + lora[a1_name], + lora[a2_name], + lora[b1_name], + lora[b2_name], + alpha, + dora_scale, + ) loaded_keys.add(a1_name) loaded_keys.add(a2_name) loaded_keys.add(b1_name) @@ -58,16 +66,28 @@ class GLoRAAdapter(WeightAdapterBase): old_glora = True if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]: - if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]: + if ( + old_glora + and v[1].shape[0] == weight.shape[0] + and weight.shape[0] == weight.shape[1] + ): pass else: old_glora = False rank = v[1].shape[0] - a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype) - a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype) - b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype) - b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype) + a1 = comfy.model_management.cast_to_device( + v[0].flatten(start_dim=1), weight.device, intermediate_dtype + ) + a2 = comfy.model_management.cast_to_device( + v[1].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b1 = comfy.model_management.cast_to_device( + v[2].flatten(start_dim=1), weight.device, intermediate_dtype + ) + b2 = comfy.model_management.cast_to_device( + v[3].flatten(start_dim=1), weight.device, intermediate_dtype + ) if v[4] is not None: alpha = v[4] / rank @@ -76,18 +96,195 @@ class GLoRAAdapter(WeightAdapterBase): try: if old_glora: - lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora + lora_diff = ( + torch.mm(b2, b1) + + torch.mm( + torch.mm( + weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2 + ), + a1, + ) + ).reshape( + weight.shape + ) # old lycoris glora else: if weight.dim() > 2: - lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i 
..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.einsum( + "o i ..., i j -> o j ...", + torch.einsum( + "o i ..., i j -> o j ...", + weight.to(dtype=intermediate_dtype), + a1, + ), + a2, + ).reshape(weight.shape) else: - lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape) + lora_diff = torch.mm( + torch.mm(weight.to(dtype=intermediate_dtype), a1), a2 + ).reshape(weight.shape) lora_diff += torch.mm(b1, b2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _compute_paths(self, x: torch.Tensor): + """ + Compute A path and B path outputs for GLoRA bypass. + + GLoRA: f(x) = Wx + WAx + Bx + - A path: a1(a2(x)) - modifies input to base forward + - B path: b1(b2(x)) - additive component + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Returns: (a_out, b_out) + """ + v = self.weights + # v = (a1, a2, b1, b2, alpha, dora_scale) + a1 = v[0] + a2 = v[1] + b1 = v[2] + b2 = v[3] + alpha = v[4] + + dtype = x.dtype + + # Cast dtype (weights should already be on correct device from inject()) + a1 = a1.to(dtype=dtype) + a2 = a2.to(dtype=dtype) + b1 = b1.to(dtype=dtype) + b2 = b2.to(dtype=dtype) + + # Determine rank and scale + # Check for old vs new glora format + old_glora = False + if b2.shape[1] == b1.shape[0] == a1.shape[0] == a2.shape[1]: + rank = a1.shape[0] + old_glora = True + + if b2.shape[0] == b1.shape[1] == a1.shape[1] == a2.shape[0]: + if old_glora and a2.shape[0] == x.shape[-1] and x.shape[-1] == x.shape[-1]: + pass + else: + old_glora = False + rank = a2.shape[0] + + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + + # Apply multiplier + multiplier = getattr(self, "multiplier", 1.0) + scale = scale * multiplier + + # Use module info from bypass injection, not input tensor shape + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + # Conv case - conv_dim is 1/2/3 for conv1d/2d/3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Get module's stride/padding for spatial dimension handling + module_stride = kw_dict.get("stride", (1,) * conv_dim) + module_padding = kw_dict.get("padding", (0,) * conv_dim) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Ensure weights are in conv shape + # a1, a2, b1 are always 1x1 kernels + if a1.ndim == 2: + a1 = a1.view(*a1.shape, *([1] * conv_dim)) + if a2.ndim == 2: + a2 = a2.view(*a2.shape, *([1] * conv_dim)) + if b1.ndim == 2: + b1 = b1.view(*b1.shape, *([1] * conv_dim)) + # b2 has actual kernel_size (like LoRA down) + if b2.ndim == 2: + if in_channels is not None: + b2 = b2.view(b2.shape[0], in_channels, *kernel_size) + else: + b2 = b2.view(*b2.shape, *([1] * conv_dim)) + + # A path: a2(x) -> a1(...) 
- 1x1 convs, no stride/padding needed, a_out is added to x + a2_out = conv_fn(x, a2) + a_out = conv_fn(a2_out, a1) * scale + + # B path: b2(x) with kernel/stride/padding -> b1(...) 1x1 + b2_out = conv_fn(x, b2, stride=module_stride, padding=module_padding) + b_out = conv_fn(b2_out, b1) * scale + else: + # Linear case + if old_glora: + # Old format: a1 @ a2 @ x, b2 @ b1 + a_out = F.linear(F.linear(x, a2), a1) * scale + b_out = F.linear(F.linear(x, b1), b2) * scale + else: + # New format: x @ a1 @ a2, b1 @ b2 + a_out = F.linear(F.linear(x, a1), a2) * scale + b_out = F.linear(F.linear(x, b2), b1) * scale + + return a_out, b_out + + def bypass_forward( + self, + org_forward: Callable, + x: torch.Tensor, + *args, + **kwargs, + ) -> torch.Tensor: + """ + GLoRA bypass forward: f(x + a(x)) + b(x) + + Unlike standard adapters, GLoRA modifies the input to the base forward + AND adds the B path output. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Reference: LyCORIS GLoRAModule._bypass_forward + """ + a_out, b_out = self._compute_paths(x) + + # Call base forward with modified input + base_out = org_forward(x + a_out, *args, **kwargs) + + # Add B path + return base_out + b_out + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + For GLoRA, h() returns the B path output. + + Note: + GLoRA's full bypass requires overriding bypass_forward() since + it also modifies the input to org_forward. This h() is provided for + compatibility but bypass_forward() should be used for correct behavior. + + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + _, b_out = self._compute_paths(x) + return b_out diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py index 0abb2d403..8007b7b44 100644 --- a/comfy/weight_adapter/loha.py +++ b/comfy/weight_adapter/loha.py @@ -1,11 +1,22 @@ import logging +from functools import cache from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose +@cache +def _warn_loha_bypass_inefficient(): + """One-time warning about LoHa bypass inefficiency.""" + logging.warning( + "LoHa bypass mode is inefficient: full weight diff is computed each forward pass. " + "Consider using LoRA or LoKr for training with bypass mode." 
+ ) + + class HadaWeight(torch.autograd.Function): @staticmethod def forward(ctx, w1u, w1d, w2u, w2d, scale=torch.tensor(1)): @@ -105,9 +116,19 @@ class LohaDiff(WeightAdapterTrainBase): scale = self.alpha / self.rank if self.use_tucker: - diff_weight = HadaWeightTucker.apply(self.hada_t1, self.hada_w1_a, self.hada_w1_b, self.hada_t2, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeightTucker.apply( + self.hada_t1, + self.hada_w1_a, + self.hada_w1_b, + self.hada_t2, + self.hada_w2_a, + self.hada_w2_b, + scale, + ) else: - diff_weight = HadaWeight.apply(self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale) + diff_weight = HadaWeight.apply( + self.hada_w1_a, self.hada_w1_b, self.hada_w2_a, self.hada_w2_b, scale + ) # Add the scaled difference to the original weight weight = w.to(diff_weight) + diff_weight.reshape(w.shape) @@ -138,9 +159,7 @@ class LoHaAdapter(WeightAdapterBase): mat4 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.normal_(mat3, 0.1) torch.nn.init.normal_(mat4, 0.01) - return LohaDiff( - (mat1, mat2, alpha, mat3, mat4, None, None, None) - ) + return LohaDiff((mat1, mat2, alpha, mat3, mat4, None, None, None)) def to_train(self): return LohaDiff(self.weights) @@ -172,7 +191,16 @@ class LoHaAdapter(WeightAdapterBase): loaded_keys.add(hada_t1_name) loaded_keys.add(hada_t2_name) - weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale) + weights = ( + lora[hada_w1_a_name], + lora[hada_w1_b_name], + alpha, + lora[hada_w2_a_name], + lora[hada_w2_b_name], + hada_t1, + hada_t2, + dora_scale, + ) loaded_keys.add(hada_w1_a_name) loaded_keys.add(hada_w1_b_name) loaded_keys.add(hada_w2_a_name) @@ -203,30 +231,148 @@ class LoHaAdapter(WeightAdapterBase): w2a = v[3] w2b = v[4] dora_scale = v[7] - if v[5] is not None: #cp decomposition + if v[5] is not None: # cp decomposition t1 = v[5] t2 = v[6] - m1 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype)) + m1 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t1, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + ) - m2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype)) + m2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + ) else: - m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype)) - m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2b, weight.device, 
intermediate_dtype)) + m1 = torch.mm( + comfy.model_management.cast_to_device( + w1a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1b, weight.device, intermediate_dtype + ), + ) + m2 = torch.mm( + comfy.model_management.cast_to_device( + w2a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2b, weight.device, intermediate_dtype + ), + ) try: lora_diff = (m1 * m2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoHa: h(x) = diff_weight @ x + + WARNING: Inefficient - computes full Hadamard product each forward. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/loha.py bypass_forward_diff + """ + _warn_loha_bypass_inefficient() + + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1a, v[1]=w1b, v[2]=alpha, v[3]=w2a, v[4]=w2b, v[5]=t1, v[6]=t2, v[7]=dora + w1a = v[0] + w1b = v[1] + alpha = v[2] + w2a = v[3] + w2b = v[4] + t1 = v[5] + t2 = v[6] + + # Compute scale + rank = w1b.shape[0] + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Cast dtype + w1a = w1a.to(dtype=x.dtype) + w1b = w1b.to(dtype=x.dtype) + w2a = w2a.to(dtype=x.dtype) + w2b = w2b.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Compute diff weight using Hadamard product + if t1 is not None and t2 is not None: + t1 = t1.to(dtype=x.dtype) + t2 = t2.to(dtype=x.dtype) + m1 = torch.einsum("i j k l, j r, i p -> p r k l", t1, w1b, w1a) + m2 = torch.einsum("i j k l, j r, i p -> p r k l", t2, w2b, w2a) + diff_weight = (m1 * m2) * scale + else: + m1 = w1a @ w1b + m2 = w2a @ w2b + diff_weight = (m1 * m2) * scale + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D diff_weight to conv format using kernel_size + # diff_weight: [out_channels, in_channels * prod(kernel_size)] -> [out_channels, in_channels, *kernel_size] + if diff_weight.dim() == 2: + if in_channels is not None: + diff_weight = diff_weight.view( + diff_weight.shape[0], in_channels, *kernel_size + ) + else: + diff_weight = diff_weight.view( + *diff_weight.shape, *([1] * conv_dim) + ) + else: + op = F.linear + kw_dict = {} + + return op(x, diff_weight, **kw_dict) diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py index 9b2aff2d7..b83750012 100644 --- a/comfy/weight_adapter/lokr.py +++ b/comfy/weight_adapter/lokr.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F 
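+
+# Added commentary: torch.nn.functional is used by the bypass h() methods below,
+# which evaluate kron(w1, w2) @ x without materializing the Kronecker product.
+# Quick self-check of the identity they rely on (hypothetical shapes):
+#   w1 = torch.randn(3, 2); w2 = torch.randn(4, 5); x = torch.randn(10)
+#   full = torch.kron(w1, w2) @ x
+#   fast = (w1 @ (x.view(2, 5) @ w2.T)).reshape(-1)
+#   torch.allclose(full, fast, atol=1e-5)  # True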
import comfy.model_management from .base import ( WeightAdapterBase, @@ -14,7 +15,17 @@ from .base import ( class LokrDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) = weights + ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) = weights self.use_tucker = False if lokr_w1_a is not None: _, rank_a = lokr_w1_a.shape[0], lokr_w1_a.shape[1] @@ -57,10 +68,10 @@ class LokrDiff(WeightAdapterTrainBase): if self.w2_rebuild: if self.use_tucker: w2 = torch.einsum( - 'i j k l, j r, i p -> p r k l', + "i j k l, j r, i p -> p r k l", self.lokr_t2, self.lokr_w2_b, - self.lokr_w2_a + self.lokr_w2_a, ) else: w2 = self.lokr_w2_a @ self.lokr_w2_b @@ -69,9 +80,89 @@ class LokrDiff(WeightAdapterTrainBase): return self.lokr_w2 def __call__(self, w): - diff = torch.kron(self.w1, self.w2) + w1 = self.w1 + w2 = self.w2 + # Unsqueeze w1 to match w2 dims for proper kron product (like LyCORIS make_kron) + for _ in range(w2.dim() - w1.dim()): + w1 = w1.unsqueeze(-1) + diff = torch.kron(w1, w2) return w + diff.reshape(w.shape).to(w) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr training: efficient Kronecker product. + + Uses w1/w2 properties which handle both direct and decomposed cases. + For create_train (direct w1/w2), no alpha scaling in properties. + For to_train (decomposed), alpha/rank scaling is in properties. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Get w1, w2 from properties (handles rebuild vs direct) + w1 = self.w1 + w2 = self.w2 + + # Multiplier from bypass injection + multiplier = getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Efficient Kronecker application without materializing full weight + # kron(w1, w2) @ x can be computed as nested operations + # w1: [out_l, in_m], w2: [out_k, in_n, *k_size] + # Full weight would be [out_l*out_k, in_m*in_n, *k_size] + + uq = w1.size(1) # in_m - inner grouping dimension + + if is_conv: + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + B, C_in, *spatial = x.shape + # Reshape input for grouped application: [B * uq, C_in // uq, *spatial] + h_in_group = x.reshape(B * uq, -1, *spatial) + + # Ensure w2 has conv dims + if w2.dim() == 2: + w2 = w2.view(*w2.shape, *([1] * conv_dim)) + + # Apply w2 path with stride/padding + hb = conv_fn(h_in_group, w2, **kw_dict) + + # Reshape for cross-group operation + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross = hb.transpose(1, -1) + + # Apply w1 (always 2D, applied as linear on channel dim) + hc = F.linear(h_cross, w1) + hc = hc.transpose(1, -1) + + # Reshape to output + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + # Linear case + # Reshape input: [..., in_m * in_n] -> [..., uq (in_m), in_n] + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2: [..., uq, in_n] @ [out_k, in_n].T -> [..., uq, out_k] + hb = F.linear(h_in_group, w2) + + # Transpose for w1: [..., uq, out_k] -> [..., out_k, uq] + h_cross = hb.transpose(-1, -2) + + # Apply w1: [..., out_k, uq] @ [out_l, uq].T -> [..., out_k, out_l] + hc = F.linear(h_cross, w1) + + # Transpose back and flatten: [..., out_k, out_l] -> [..., out_l * out_k] + hc = hc.transpose(-1, -2) + out = 
hc.reshape(*hc.shape[:-2], -1) + + return out * multiplier + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -86,16 +177,22 @@ class LoKrAdapter(WeightAdapterBase): @classmethod def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] - in_dim = weight.shape[1:].numel() - out1, out2 = factorization(out_dim, rank) - in1, in2 = factorization(in_dim, rank) - mat1 = torch.empty(out1, in1, device=weight.device, dtype=torch.float32) - mat2 = torch.empty(out2, in2, device=weight.device, dtype=torch.float32) + in_dim = weight.shape[1] # Just in_channels, not flattened with kernel + k_size = weight.shape[2:] if weight.dim() > 2 else () + + out_l, out_k = factorization(out_dim, rank) + in_m, in_n = factorization(in_dim, rank) + + # w1: [out_l, in_m] + mat1 = torch.empty(out_l, in_m, device=weight.device, dtype=torch.float32) + # w2: [out_k, in_n, *k_size] for conv, [out_k, in_n] for linear + mat2 = torch.empty( + out_k, in_n, *k_size, device=weight.device, dtype=torch.float32 + ) + torch.nn.init.kaiming_uniform_(mat2, a=5**0.5) torch.nn.init.constant_(mat1, 0.0) - return LokrDiff( - (mat1, mat2, alpha, None, None, None, None, None, None) - ) + return LokrDiff((mat1, mat2, alpha, None, None, None, None, None, None)) def to_train(self): return LokrDiff(self.weights) @@ -154,8 +251,23 @@ class LoKrAdapter(WeightAdapterBase): lokr_t2 = lora[lokr_t2_name] loaded_keys.add(lokr_t2_name) - if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None): - weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale) + if ( + (lokr_w1 is not None) + or (lokr_w2 is not None) + or (lokr_w1_a is not None) + or (lokr_w2_a is not None) + ): + weights = ( + lokr_w1, + lokr_w2, + alpha, + lokr_w1_a, + lokr_w1_b, + lokr_w2_a, + lokr_w2_b, + lokr_t2, + dora_scale, + ) return cls(loaded_keys, weights) else: return None @@ -184,23 +296,47 @@ class LoKrAdapter(WeightAdapterBase): if w1 is None: dim = w1_b.shape[0] - w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype)) + w1 = torch.mm( + comfy.model_management.cast_to_device( + w1_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w1_b, weight.device, intermediate_dtype + ), + ) else: - w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype) + w1 = comfy.model_management.cast_to_device( + w1, weight.device, intermediate_dtype + ) if w2 is None: dim = w2_b.shape[0] if t2 is None: - w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype)) + w2 = torch.mm( + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, weight.device, intermediate_dtype + ), + ) else: - w2 = torch.einsum('i j k l, j r, i p -> p r k l', - comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype), - comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype)) + w2 = torch.einsum( + "i j k l, j r, i p -> p r k l", + comfy.model_management.cast_to_device( + t2, weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_b, 
weight.device, intermediate_dtype + ), + comfy.model_management.cast_to_device( + w2_a, weight.device, intermediate_dtype + ), + ) else: - w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype) + w2 = comfy.model_management.cast_to_device( + w2, weight.device, intermediate_dtype + ) if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) @@ -212,9 +348,134 @@ class LoKrAdapter(WeightAdapterBase): try: lora_diff = torch.kron(w1, w2).reshape(weight.shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function(((strength * alpha) * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoKr: efficient Kronecker product application. + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/lokr.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=w1, v[1]=w2, v[2]=alpha, v[3]=w1_a, v[4]=w1_b, v[5]=w2_a, v[6]=w2_b, v[7]=t2, v[8]=dora + w1 = v[0] + w2 = v[1] + alpha = v[2] + w1_a = v[3] + w1_b = v[4] + w2_a = v[5] + w2_b = v[6] + t2 = v[7] + + use_w1 = w1 is not None + use_w2 = w2 is not None + tucker = t2 is not None + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) if is_conv else {} + + if is_conv: + op = FUNC_LIST[conv_dim + 2] + else: + op = F.linear + + # Determine rank and scale + rank = w1_b.size(0) if not use_w1 else w2_b.size(0) if not use_w2 else alpha + scale = (alpha / rank if alpha is not None else 1.0) * getattr( + self, "multiplier", 1.0 + ) + + # Build c (w1) + if use_w1: + c = w1.to(dtype=x.dtype) + else: + c = w1_a.to(dtype=x.dtype) @ w1_b.to(dtype=x.dtype) + uq = c.size(1) + + # Build w2 components + if use_w2: + ba = w2.to(dtype=x.dtype) + else: + a = w2_b.to(dtype=x.dtype) + b = w2_a.to(dtype=x.dtype) + if is_conv: + if tucker: + # Tucker: a, b get 1s appended (kernel is in t2) + if a.dim() == 2: + a = a.view(*a.shape, *([1] * conv_dim)) + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + else: + # Non-tucker conv: b may need 1s appended + if b.dim() == 2: + b = b.view(*b.shape, *([1] * conv_dim)) + + # Reshape input by uq groups + if is_conv: + B, _, *rest = x.shape + h_in_group = x.reshape(B * uq, -1, *rest) + else: + h_in_group = x.reshape(*x.shape[:-1], uq, -1) + + # Apply w2 path + if use_w2: + hb = op(h_in_group, ba, **kw_dict) + else: + if is_conv: + if tucker: + t = t2.to(dtype=x.dtype) + if t.dim() == 2: + t = t.view(*t.shape, *([1] * conv_dim)) + ha = op(h_in_group, a) + ht = op(ha, t, **kw_dict) + hb = op(ht, b) + else: + ha = op(h_in_group, a, **kw_dict) + hb = op(ha, b) + else: + ha = op(h_in_group, a) + hb = op(ha, b) + + # Reshape and apply c (w1) + if is_conv: + hb = hb.view(B, -1, *hb.shape[1:]) + h_cross_group = hb.transpose(1, -1) 
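+            # Added note: the uq group axis is now last, so the F.linear call
+            # below contracts it against c (w1, shape [out_l, uq]).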
+ else: + h_cross_group = hb.transpose(-1, -2) + + hc = F.linear(h_cross_group, c) + + if is_conv: + hc = hc.transpose(1, -1) + out = hc.reshape(B, -1, *hc.shape[3:]) + else: + hc = hc.transpose(-1, -2) + out = hc.reshape(*hc.shape[:-2], -1) + + return out * scale diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 3cc60bb1b..8e1261a12 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -2,6 +2,7 @@ import logging from typing import Optional import torch +import torch.nn.functional as F import comfy.model_management from .base import ( WeightAdapterBase, @@ -20,11 +21,7 @@ class LoraDiff(WeightAdapterTrainBase): rank, in_dim = mat2.shape[0], mat2.shape[1] if mid is not None: convdim = mid.ndim - 2 - layer = ( - torch.nn.Conv1d, - torch.nn.Conv2d, - torch.nn.Conv3d - )[convdim] + layer = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)[convdim] else: layer = torch.nn.Linear self.lora_up = layer(rank, out_dim, bias=False) @@ -51,6 +48,78 @@ class LoraDiff(WeightAdapterTrainBase): weight = w + scale * diff.reshape(w.shape) return weight.to(org_dtype) + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA training: h(x) = up(down(x)) * scale + + Simple implementation using the nn.Module weights directly. + No mid/dora/reshape branches (create_train doesn't create them). + + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + """ + # Compute scale = alpha / rank * multiplier + scale = (self.alpha / self.rank) * getattr(self, "multiplier", 1.0) + + # Get module info from bypass injection + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + # Get weights (keep in original dtype for numerical stability) + down_weight = self.lora_down.weight + up_weight = self.lora_up.weight + + if is_conv: + # Conv path: use functional conv + # conv_dim: 1=conv1d, 2=conv2d, 3=conv3d + conv_fn = (F.conv1d, F.conv2d, F.conv3d)[conv_dim - 1] + + # Reshape 2D weights to conv format if needed + # down: [rank, in_features] -> [rank, in_channels, *kernel_size] + # up: [out_features, rank] -> [out_features, rank, 1, 1, ...] 
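+            # Added note: this mirrors the LoCon layout: down carries the
+            # spatial kernel, while mid/up act as pointwise (1x1) convolutions.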
+ if down_weight.dim() == 2: + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + if in_channels is not None: + down_weight = down_weight.view( + down_weight.shape[0], in_channels, *kernel_size + ) + else: + # Fallback: assume 1x1 kernel + down_weight = down_weight.view( + *down_weight.shape, *([1] * conv_dim) + ) + if up_weight.dim() == 2: + # up always uses 1x1 kernel + up_weight = up_weight.view(*up_weight.shape, *([1] * conv_dim)) + + # down conv uses stride/padding from module, up is 1x1 + hidden = conv_fn(x, down_weight, **kw_dict) + + # mid layer if exists (tucker decomposition) + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + if mid_weight.dim() == 2: + mid_weight = mid_weight.view(*mid_weight.shape, *([1] * conv_dim)) + hidden = conv_fn(hidden, mid_weight) + + # up conv is always 1x1 (no stride/padding) + out = conv_fn(hidden, up_weight) + else: + # Linear path: simple matmul chain + hidden = F.linear(x, down_weight) + + # mid layer if exists + if self.lora_mid is not None: + mid_weight = self.lora_mid.weight + hidden = F.linear(hidden, mid_weight) + + out = F.linear(hidden, up_weight) + + return out * scale + def passive_memory_usage(self): return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -70,9 +139,7 @@ class LoRAAdapter(WeightAdapterBase): mat2 = torch.empty(rank, in_dim, device=weight.device, dtype=torch.float32) torch.nn.init.kaiming_uniform_(mat1, a=5**0.5) torch.nn.init.constant_(mat2, 0.0) - return LoraDiff( - (mat1, mat2, alpha, None, None, None) - ) + return LoraDiff((mat1, mat2, alpha, None, None, None)) def to_train(self): return LoraDiff(self.weights) @@ -147,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase): else: return None + def calculate_shape( + self, + key + ): + reshape = self.weights[5] + return tuple(reshape) if reshape is not None else None + def calculate_weight( self, weight, @@ -210,3 +284,85 @@ class LoRAAdapter(WeightAdapterBase): except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + Additive bypass component for LoRA: h(x) = up(down(x)) * scale + + Note: + Does not access original model weights - bypass mode is designed + for quantized models where weights may not be accessible. 
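+
+        Example (illustrative only, made-up shapes; shows the linear-case
+        equivalence this method relies on)::
+
+            x = torch.randn(1, 64)
+            up, down = torch.randn(64, 8), torch.randn(8, 64)
+            full = F.linear(x, up @ down)             # materialized diff weight
+            bypass = F.linear(F.linear(x, down), up)  # what h() computes
+            torch.allclose(full, bypass, atol=1e-5)   # True; h() then scales by alpha/rank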
+ + Args: + x: Input tensor + base_out: Output from base forward (unused, for API consistency) + + Reference: LyCORIS functional/locon.py bypass_forward_diff + """ + # FUNC_LIST: [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + FUNC_LIST = [None, None, F.linear, F.conv1d, F.conv2d, F.conv3d] + + v = self.weights + # v[0]=up, v[1]=down, v[2]=alpha, v[3]=mid, v[4]=dora_scale, v[5]=reshape + up = v[0] + down = v[1] + alpha = v[2] + mid = v[3] + + # Compute scale = alpha / rank + rank = down.shape[0] + if alpha is not None: + scale = alpha / rank + else: + scale = 1.0 + scale = scale * getattr(self, "multiplier", 1.0) + + # Cast dtype + up = up.to(dtype=x.dtype) + down = down.to(dtype=x.dtype) + + # Use module info from bypass injection, not weight dimension + is_conv = getattr(self, "is_conv", False) + conv_dim = getattr(self, "conv_dim", 0) + kw_dict = getattr(self, "kw_dict", {}) + + if is_conv: + op = FUNC_LIST[ + conv_dim + 2 + ] # conv_dim 1->conv1d(3), 2->conv2d(4), 3->conv3d(5) + kernel_size = getattr(self, "kernel_size", (1,) * conv_dim) + in_channels = getattr(self, "in_channels", None) + + # Reshape 2D weights to conv format using kernel_size + # down: [rank, in_channels * prod(kernel_size)] -> [rank, in_channels, *kernel_size] + # up: [out_channels, rank] -> [out_channels, rank, 1, 1, ...] (1x1 kernel) + if down.dim() == 2: + # down.shape[1] = in_channels * prod(kernel_size) + if in_channels is not None: + down = down.view(down.shape[0], in_channels, *kernel_size) + else: + # Fallback: assume 1x1 kernel if in_channels unknown + down = down.view(*down.shape, *([1] * conv_dim)) + if up.dim() == 2: + # up always uses 1x1 kernel + up = up.view(*up.shape, *([1] * conv_dim)) + if mid is not None: + mid = mid.to(dtype=x.dtype) + if mid.dim() == 2: + mid = mid.view(*mid.shape, *([1] * conv_dim)) + else: + op = F.linear + kw_dict = {} # linear doesn't take stride/padding + + # Simple chain: down -> mid (if tucker) -> up + if mid is not None: + if not is_conv: + mid = mid.to(dtype=x.dtype) + hidden = op(x, down) + hidden = op(hidden, mid, **kw_dict) + out = op(hidden, up) + else: + hidden = op(x, down, **kw_dict) + out = op(hidden, up) + + return out * scale diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py index c0aab9635..bc83cf8e8 100644 --- a/comfy/weight_adapter/oft.py +++ b/comfy/weight_adapter/oft.py @@ -3,13 +3,18 @@ from typing import Optional import torch import comfy.model_management -from .base import WeightAdapterBase, WeightAdapterTrainBase, weight_decompose, factorization +from .base import ( + WeightAdapterBase, + WeightAdapterTrainBase, + weight_decompose, + factorization, +) class OFTDiff(WeightAdapterTrainBase): def __init__(self, weights): super().__init__() - # Unpack weights tuple from LoHaAdapter + # Unpack weights tuple from OFTAdapter blocks, rescale, alpha, _ = weights # Create trainable parameters @@ -52,6 +57,78 @@ class OFTDiff(WeightAdapterTrainBase): weight = self.rescale * weight return weight.to(org_dtype) + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + blocks = self.oft_blocks.to(device=device, dtype=dtype) + I = torch.eye(self.block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if set + if self.constraint: + q_norm = torch.norm(q) + 1e-8 + if q_norm > self.constraint: + normed_q = q * self.constraint / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 
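+        # Added note: since Q is skew-symmetric, (I + Q)^T = I - Q and the two
+        # factors commute, so R^T R = I and R is orthogonal by construction.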
+ r = (I + normed_q) @ (I - normed_q).float().inverse() + return r.to(dtype) + + def h(self, x: torch.Tensor, base_out: torch.Tensor) -> torch.Tensor: + """ + OFT has no additive component - returns zeros matching base_out shape. + + OFT only transforms the output via g(), it doesn't add to it. + """ + return torch.zeros_like(base_out) + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation. + + OFT transforms output channels using block-diagonal orthogonal matrices. + """ + r = self._get_orthogonal_matrix(y.device, y.dtype) + + # Apply multiplier to interpolate between identity and full transform + multiplier = getattr(self, "multiplier", 1.0) + I = torch.eye(self.block_size, device=y.device, dtype=y.dtype) + r = r * multiplier + (1 - multiplier) * I + + # Use module info from bypass injection + is_conv = getattr(self, "is_conv", y.dim() > 2) + + if is_conv: + # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C) + y = y.transpose(1, -1) + + # y now has channels in last dim + *batch_shape, out_features = y.shape + + # Reshape to apply block-diagonal transform + # (*, out_features) -> (*, block_num, block_size) + y_blocked = y.reshape(*batch_shape, self.block_num, self.block_size) + + # Apply orthogonal transform: R @ y for each block + # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size) + out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked) + + # Reshape back: (*, block_num, block_size) -> (*, out_features) + out = out_blocked.reshape(*batch_shape, out_features) + + # Apply rescale if present + if self.rescaled: + rescale = self.rescale.to(device=y.device, dtype=y.dtype) + out = out * rescale.view(-1) + + if is_conv: + # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...) 
+ out = out.transpose(1, -1) + + return out + def passive_memory_usage(self): """Calculates memory usage of the trainable parameters.""" return sum(param.numel() * param.element_size() for param in self.parameters()) @@ -68,10 +145,10 @@ class OFTAdapter(WeightAdapterBase): def create_train(cls, weight, rank=1, alpha=1.0): out_dim = weight.shape[0] block_size, block_num = factorization(out_dim, rank) - block = torch.zeros(block_num, block_size, block_size, device=weight.device, dtype=torch.float32) - return OFTDiff( - (block, None, alpha, None) + block = torch.zeros( + block_num, block_size, block_size, device=weight.device, dtype=torch.float32 ) + return OFTDiff((block, None, alpha, None)) def to_train(self): return OFTDiff(self.weights) @@ -127,9 +204,13 @@ class OFTAdapter(WeightAdapterBase): alpha = 0 dora_scale = v[3] - blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype) + blocks = comfy.model_management.cast_to_device( + blocks, weight.device, intermediate_dtype + ) if rescale is not None: - rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype) + rescale = comfy.model_management.cast_to_device( + rescale, weight.device, intermediate_dtype + ) block_num, block_size, *_ = blocks.shape @@ -139,23 +220,108 @@ class OFTAdapter(WeightAdapterBase): # for Q = -Q^T q = blocks - blocks.transpose(1, 2) normed_q = q - if alpha > 0: # alpha in oft/boft is for constraint + if alpha > 0: # alpha in oft/boft is for constraint q_norm = torch.norm(q) + 1e-8 if q_norm > alpha: normed_q = q * alpha / q_norm # use float() to prevent unsupported type in .inverse() r = (I + normed_q) @ (I - normed_q).float().inverse() r = r.to(weight) + # Create I in weight's dtype for the einsum + I_w = torch.eye(block_size, device=weight.device, dtype=weight.dtype) _, *shape = weight.shape lora_diff = torch.einsum( "k n m, k n ... -> k m ...", - (r * strength) - strength * I, + (r * strength) - strength * I_w, weight.view(block_num, block_size, *shape), ).view(-1, *shape) if dora_scale is not None: - weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function) + weight = weight_decompose( + dora_scale, + weight, + lora_diff, + alpha, + strength, + intermediate_dtype, + function, + ) else: weight += function((strength * lora_diff).type(weight.dtype)) except Exception as e: logging.error("ERROR {} {} {}".format(self.name, key, e)) return weight + + def _get_orthogonal_matrix(self, device, dtype): + """Compute the orthogonal rotation matrix R from OFT blocks.""" + v = self.weights + blocks = v[0].to(device=device, dtype=dtype) + alpha = v[2] + if alpha is None: + alpha = 0 + + block_num, block_size, _ = blocks.shape + I = torch.eye(block_size, device=device, dtype=dtype) + + # Q = blocks - blocks^T (skew-symmetric) + q = blocks - blocks.transpose(1, 2) + normed_q = q + + # Apply constraint if alpha > 0 + if alpha > 0: + q_norm = torch.norm(q) + 1e-8 + if q_norm > alpha: + normed_q = q * alpha / q_norm + + # Cayley transform: R = (I + Q)(I - Q)^-1 + r = (I + normed_q) @ (I - normed_q).float().inverse() + return r, block_num, block_size + + def g(self, y: torch.Tensor) -> torch.Tensor: + """ + Output transformation for OFT: applies orthogonal rotation to output. + + OFT transforms the output channels using block-diagonal orthogonal matrices. 
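+
+        For a linear layer this is equivalent to the weight-space patch:
+        applying the block-diagonal rotation to the output of W @ x matches
+        applying the same rotation to the rows of W, without ever touching
+        the (possibly quantized) weight itself.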
+
+        Reference: LyCORIS DiagOFTModule._bypass_forward
+        """
+        v = self.weights
+        rescale = v[1]
+
+        r, block_num, block_size = self._get_orthogonal_matrix(y.device, y.dtype)
+
+        # Apply multiplier to interpolate between identity and full transform
+        multiplier = getattr(self, "multiplier", 1.0)
+        I = torch.eye(block_size, device=y.device, dtype=y.dtype)
+        r = r * multiplier + (1 - multiplier) * I
+
+        # Use module info from bypass injection to determine conv vs linear
+        is_conv = getattr(self, "is_conv", y.dim() > 2)
+
+        if is_conv:
+            # Conv output: (N, C, H, W, ...) -> transpose to (N, H, W, ..., C)
+            y = y.transpose(1, -1)
+
+        # y now has channels in last dim
+        *batch_shape, out_features = y.shape
+
+        # Reshape to apply block-diagonal transform
+        # (*, out_features) -> (*, block_num, block_size)
+        y_blocked = y.view(*batch_shape, block_num, block_size)
+
+        # Apply orthogonal transform: R @ y for each block
+        # r: (block_num, block_size, block_size), y_blocked: (*, block_num, block_size)
+        out_blocked = torch.einsum("k n m, ... k n -> ... k m", r, y_blocked)
+
+        # Reshape back: (*, block_num, block_size) -> (*, out_features)
+        out = out_blocked.view(*batch_shape, out_features)
+
+        # Apply rescale if present
+        if rescale is not None:
+            rescale = rescale.to(device=y.device, dtype=y.dtype)
+            out = out * rescale.view(-1)
+
+        if is_conv:
+            # Transpose back: (N, H, W, ..., C) -> (N, C, H, W, ...)
+            out = out.transpose(1, -1)
+
+        return out
diff --git a/comfy/windows.py b/comfy/windows.py
new file mode 100644
index 000000000..213dc481d
--- /dev/null
+++ b/comfy/windows.py
@@ -0,0 +1,52 @@
+import ctypes
+import logging
+import psutil
+from ctypes import wintypes
+
+import comfy_aimdo.control
+
+psapi = ctypes.WinDLL("psapi")
+kernel32 = ctypes.WinDLL("kernel32")
+
+class PERFORMANCE_INFORMATION(ctypes.Structure):
+    _fields_ = [
+        ("cb", wintypes.DWORD),
+        ("CommitTotal", ctypes.c_size_t),
+        ("CommitLimit", ctypes.c_size_t),
+        ("CommitPeak", ctypes.c_size_t),
+        ("PhysicalTotal", ctypes.c_size_t),
+        ("PhysicalAvailable", ctypes.c_size_t),
+        ("SystemCache", ctypes.c_size_t),
+        ("KernelTotal", ctypes.c_size_t),
+        ("KernelPaged", ctypes.c_size_t),
+        ("KernelNonpaged", ctypes.c_size_t),
+        ("PageSize", ctypes.c_size_t),
+        ("HandleCount", wintypes.DWORD),
+        ("ProcessCount", wintypes.DWORD),
+        ("ThreadCount", wintypes.DWORD),
+    ]
+
+def get_free_ram():
+    # Windows is overly conservative and counts recently used, uncommitted model
+    # RAM as "in use". So, calculate free RAM for general use as the greater of:
+    #
+    # 1: What psutil says
+    # 2: Total Memory - (Committed Memory - VRAM in use)
+    #
+    # We have to subtract VRAM in use from the committed memory because WDDM
+    # creates a naked commit charge for all VRAM used, just in case it wants to
+    # page it all out. That isn't realistic, so we "overcommit" in our
+    # calculation by subtracting it off.
+
+    pi = PERFORMANCE_INFORMATION()
+    pi.cb = ctypes.sizeof(pi)
+
+    if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb):
+        logging.warning("Failed to query Windows performance info.
RAM usage may be sub optimal") + return psutil.virtual_memory().available + + committed = pi.CommitTotal * pi.PageSize + total = pi.PhysicalTotal * pi.PageSize + + return max(psutil.virtual_memory().available, + total - (committed - comfy_aimdo.control.get_total_vram_usage())) + diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index de167f037..9f6918315 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -14,6 +14,8 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "extension": {"manager": {"supports_v4": True}}, + "node_replacements": True, + "assets": args.enable_assets, } diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 68ff78270..16d4acfd1 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -5,6 +5,10 @@ from comfy_api.latest._input import ( MaskInput, LatentInput, VideoInput, + CurvePoint, + CurveInput, + MonotoneCubicCurve, + LinearCurve, ) __all__ = [ @@ -13,4 +17,8 @@ __all__ = [ "MaskInput", "LatentInput", "VideoInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index fab63c7df..04973fea0 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -7,10 +7,9 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D from . import _io_public as io from . import _ui_public as ui -# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple from PIL import Image @@ -22,6 +21,18 @@ class ComfyAPI_latest(ComfyAPIBase): VERSION = "latest" STABLE = False + def __init__(self): + super().__init__() + self.node_replacement = self.NodeReplacement() + self.execution = self.Execution() + self.caching = self.Caching() + + class NodeReplacement(ProxiedSingleton): + async def register(self, node_replace: io.NodeReplace) -> None: + """Register a node replacement mapping.""" + from server import PromptServer + PromptServer.instance.node_replace_manager.register(node_replace) + class Execution(ProxiedSingleton): async def set_progress( self, @@ -74,7 +85,35 @@ class ComfyAPI_latest(ComfyAPIBase): image=to_display, ) - execution: Execution + class Caching(ProxiedSingleton): + """ + External cache provider API for sharing cached node outputs + across ComfyUI instances. + + Example:: + + from comfy_api.latest import Caching + + class MyCacheProvider(Caching.CacheProvider): + async def on_lookup(self, context): + ... # check external storage + + async def on_store(self, context, value): + ... # store to external storage + + Caching.register_provider(MyCacheProvider()) + """ + from ._caching import CacheProvider, CacheContext, CacheValue + + async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Register an external cache provider. 
Providers are called in registration order.""" + from comfy_execution.cache_provider import register_cache_provider + register_cache_provider(provider) + + async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None: + """Unregister a previously registered cache provider.""" + from comfy_execution.cache_provider import unregister_cache_provider + unregister_cache_provider(provider) class ComfyExtension(ABC): async def on_load(self) -> None: @@ -106,6 +145,10 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + File3D = File3D + + +Caching = ComfyAPI_latest.Caching ComfyAPI = ComfyAPI_latest @@ -126,6 +169,7 @@ __all__ = [ "Input", "InputImpl", "Types", + "Caching", "ComfyExtension", "io", "IO", diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py new file mode 100644 index 000000000..30c8848cd --- /dev/null +++ b/comfy_api/latest/_caching.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +from typing import Optional +from dataclasses import dataclass + + +@dataclass +class CacheContext: + node_id: str + class_type: str + cache_key_hash: str # SHA256 hex digest + + +@dataclass +class CacheValue: + outputs: list + ui: dict = None + + +class CacheProvider(ABC): + """Abstract base class for external cache providers. + Exceptions from provider methods are caught by the caller and never break execution. + """ + + @abstractmethod + async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]: + """Called on local cache miss. Return CacheValue if found, None otherwise.""" + pass + + @abstractmethod + async def on_store(self, context: CacheContext, value: CacheValue) -> None: + """Called after local store. Dispatched via asyncio.create_task.""" + pass + + def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool: + """Return False to skip external caching for this node. Default: True.""" + return True + + def on_prompt_start(self, prompt_id: str) -> None: + pass + + def on_prompt_end(self, prompt_id: str) -> None: + pass diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 14f0e72f4..05cd3d40a 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,4 +1,5 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput +from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .video_types import VideoInput __all__ = [ @@ -7,4 +8,8 @@ __all__ = [ "VideoInput", "MaskInput", "LatentInput", + "CurvePoint", + "CurveInput", + "MonotoneCubicCurve", + "LinearCurve", ] diff --git a/comfy_api/latest/_input/curve_types.py b/comfy_api/latest/_input/curve_types.py new file mode 100644 index 000000000..b6dd7adf9 --- /dev/null +++ b/comfy_api/latest/_input/curve_types.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import logging +import math +from abc import ABC, abstractmethod +import numpy as np + +logger = logging.getLogger(__name__) + + +CurvePoint = tuple[float, float] + + +class CurveInput(ABC): + """Abstract base class for curve inputs. + + Subclasses represent different curve representations (control-point + interpolation, analytical functions, LUT-based, etc.) while exposing a + uniform evaluation interface to downstream nodes. 
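+
+    Example (hypothetical control points)::
+
+        curve = CurveInput.from_raw([(0.0, 0.0), (0.5, 0.8), (1.0, 1.0)])
+        y = curve.interp(0.25)   # scalar evaluation
+        lut = curve.to_lut(256)  # 256-sample float64 LUT over [0, 1]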
+ """ + + @property + @abstractmethod + def points(self) -> list[CurvePoint]: + """The control points that define this curve.""" + + @abstractmethod + def interp(self, x: float) -> float: + """Evaluate the curve at a single *x* value in [0, 1].""" + + def interp_array(self, xs: np.ndarray) -> np.ndarray: + """Vectorised evaluation over a numpy array of x values. + + Subclasses should override this for better performance. The default + falls back to scalar ``interp`` calls. + """ + return np.fromiter((self.interp(float(x)) for x in xs), dtype=np.float64, count=len(xs)) + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table of *size* evenly-spaced samples in [0, 1].""" + return self.interp_array(np.linspace(0.0, 1.0, size)) + + @staticmethod + def from_raw(data) -> CurveInput: + """Convert raw curve data (dict or point list) to a CurveInput instance. + + Accepts: + - A ``CurveInput`` instance (returned as-is). + - A dict with ``"points"`` and optional ``"interpolation"`` keys. + - A bare list/sequence of ``(x, y)`` pairs (defaults to monotone cubic). + """ + if isinstance(data, CurveInput): + return data + if isinstance(data, dict): + raw_points = data["points"] + interpolation = data.get("interpolation", "monotone_cubic") + else: + raw_points = data + interpolation = "monotone_cubic" + points = [(float(x), float(y)) for x, y in raw_points] + if interpolation == "linear": + return LinearCurve(points) + if interpolation != "monotone_cubic": + logger.warning("Unknown curve interpolation %r, falling back to monotone_cubic", interpolation) + return MonotoneCubicCurve(points) + + +class MonotoneCubicCurve(CurveInput): + """Monotone cubic Hermite interpolation over control points. + + Mirrors the frontend ``createMonotoneInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts`` so that + backend evaluation matches the editor preview exactly. + + All heavy work (sorting, slope computation) happens once at construction. + ``interp_array`` is fully vectorised with numpy. 
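+
+    Example (illustrative; the monotone slope limiting prevents overshoot
+    between control points)::
+
+        curve = MonotoneCubicCurve([(0.0, 0.0), (0.5, 1.0), (1.0, 1.0)])
+        curve.interp_array(np.linspace(0.0, 1.0, 5))  # stays within [0, 1]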
+ """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + self._slopes = self._compute_slopes() + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def _compute_slopes(self) -> np.ndarray: + xs, ys = self._xs, self._ys + n = len(xs) + if n < 2: + return np.zeros(n, dtype=np.float64) + + dx = np.diff(xs) + dy = np.diff(ys) + dx_safe = np.where(dx == 0, 1.0, dx) + deltas = np.where(dx == 0, 0.0, dy / dx_safe) + + slopes = np.empty(n, dtype=np.float64) + slopes[0] = deltas[0] + slopes[-1] = deltas[-1] + for i in range(1, n - 1): + if deltas[i - 1] * deltas[i] <= 0: + slopes[i] = 0.0 + else: + slopes[i] = (deltas[i - 1] + deltas[i]) / 2 + + for i in range(n - 1): + if deltas[i] == 0: + slopes[i] = 0.0 + slopes[i + 1] = 0.0 + else: + alpha = slopes[i] / deltas[i] + beta = slopes[i + 1] / deltas[i] + s = alpha * alpha + beta * beta + if s > 9: + t = 3 / math.sqrt(s) + slopes[i] = t * alpha * deltas[i] + slopes[i + 1] = t * beta * deltas[i] + return slopes + + def interp(self, x: float) -> float: + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + if x <= xs[0]: + return float(ys[0]) + if x >= xs[-1]: + return float(ys[-1]) + + hi = int(np.searchsorted(xs, x, side='right')) + hi = min(hi, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + if dx == 0: + return float(ys[lo]) + + t = (x - xs[lo]) / dx + t2 = t * t + t3 = t2 * t + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + return float(h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi]) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + """Fully vectorised evaluation using numpy.""" + xs, ys, slopes = self._xs, self._ys, self._slopes + n = len(xs) + if n == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if n == 1: + return np.full_like(xs_in, ys[0], dtype=np.float64) + + hi = np.searchsorted(xs, xs_in, side='right').clip(1, n - 1) + lo = hi - 1 + + dx = xs[hi] - xs[lo] + dx_safe = np.where(dx == 0, 1.0, dx) + t = np.where(dx == 0, 0.0, (xs_in - xs[lo]) / dx_safe) + t2 = t * t + t3 = t2 * t + + h00 = 2 * t3 - 3 * t2 + 1 + h10 = t3 - 2 * t2 + t + h01 = -2 * t3 + 3 * t2 + h11 = t3 - t2 + + result = h00 * ys[lo] + h10 * dx * slopes[lo] + h01 * ys[hi] + h11 * dx * slopes[hi] + result = np.where(xs_in <= xs[0], ys[0], result) + result = np.where(xs_in >= xs[-1], ys[-1], result) + return result + + def __repr__(self) -> str: + return f"MonotoneCubicCurve(points={self._points})" + + +class LinearCurve(CurveInput): + """Piecewise linear interpolation over control points. + + Mirrors the frontend ``createLinearInterpolator`` in + ``ComfyUI_frontend/src/components/curve/curveUtils.ts``. 
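+
+    Example (illustrative)::
+
+        curve = LinearCurve([(0.0, 0.0), (1.0, 1.0)])
+        curve.interp(0.25)  # 0.25, plain piecewise-linear via np.interp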
+ """ + + def __init__(self, control_points: list[CurvePoint]): + sorted_pts = sorted(control_points, key=lambda p: p[0]) + self._points = [(float(x), float(y)) for x, y in sorted_pts] + self._xs = np.array([p[0] for p in self._points], dtype=np.float64) + self._ys = np.array([p[1] for p in self._points], dtype=np.float64) + + @property + def points(self) -> list[CurvePoint]: + return list(self._points) + + def interp(self, x: float) -> float: + xs, ys = self._xs, self._ys + n = len(xs) + if n == 0: + return 0.0 + if n == 1: + return float(ys[0]) + return float(np.interp(x, xs, ys)) + + def interp_array(self, xs_in: np.ndarray) -> np.ndarray: + if len(self._xs) == 0: + return np.zeros_like(xs_in, dtype=np.float64) + if len(self._xs) == 1: + return np.full_like(xs_in, self._ys[0], dtype=np.float64) + return np.interp(xs_in, self._xs, self._ys) + + def __repr__(self) -> str: + return f"LinearCurve(points={self._points})" diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index e634a0311..451e9526e 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -34,6 +34,21 @@ class VideoInput(ABC): """ pass + @abstractmethod + def as_trimmed( + self, + start_time: float | None = None, + duration: float | None = None, + strict_duration: bool = False, + ) -> VideoInput | None: + """ + Create a new VideoInput which is trimmed to have the corresponding start_time and duration + + Returns: + A new VideoInput, or None if the result would have negative duration + """ + pass + def get_stream_source(self) -> Union[str, io.BytesIO]: """ Get a streamable source for the video. This allows processing without diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index ea35c6062..1b4993aa7 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -6,6 +6,7 @@ from typing import Optional from .._input import AudioInput, VideoInput import av import io +import itertools import json import numpy as np import math @@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None: formats = container_format.split(",") return formats[0] - def get_open_write_kwargs( dest: str | io.BytesIO, container_format: str, to_format: str | None ) -> dict: @@ -57,12 +57,14 @@ class VideoFromFile(VideoInput): Class representing video input from a file. """ - def __init__(self, file: str | io.BytesIO): + def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0): """ Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object containing the file contents. 
""" self.__file = file + self.__start_time = start_time + self.__duration = duration def get_stream_source(self) -> str | io.BytesIO: """ @@ -96,6 +98,16 @@ class VideoFromFile(VideoInput): Returns: Duration in seconds """ + raw_duration = self._get_raw_duration() + if self.__start_time < 0: + duration_from_start = min(raw_duration, -self.__start_time) + else: + duration_from_start = raw_duration - self.__start_time + if self.__duration: + return min(self.__duration, duration_from_start) + return duration_from_start + + def _get_raw_duration(self) -> float: if isinstance(self.__file, io.BytesIO): self.__file.seek(0) with av.open(self.__file, mode="r") as container: @@ -113,9 +125,13 @@ class VideoFromFile(VideoInput): if video_stream and video_stream.average_rate: frame_count = 0 container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 + frame_iterator = ( + container.decode(video_stream) + if video_stream.codec.capabilities & 0x100 + else container.demux(video_stream) + ) + for packet in frame_iterator: + frame_count += 1 if frame_count > 0: return float(frame_count / video_stream.average_rate) @@ -131,36 +147,54 @@ class VideoFromFile(VideoInput): with av.open(self.__file, mode="r") as container: video_stream = self._get_first_video_stream(container) - # 1. Prefer the frames field if available - if video_stream.frames and video_stream.frames > 0: + # 1. Prefer the frames field if available and usable + if ( + video_stream.frames + and video_stream.frames > 0 + and not self.__start_time + and not self.__duration + ): return int(video_stream.frames) # 2. Try to estimate from duration and average_rate using only metadata - if container.duration is not None and video_stream.average_rate: - duration_seconds = float(container.duration / av.time_base) - estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) - if estimated_frames > 0: - return estimated_frames - if ( getattr(video_stream, "duration", None) is not None and getattr(video_stream, "time_base", None) is not None and video_stream.average_rate ): - duration_seconds = float(video_stream.duration * video_stream.time_base) + raw_duration = float(video_stream.duration * video_stream.time_base) + if self.__start_time < 0: + duration_from_start = min(raw_duration, -self.__start_time) + else: + duration_from_start = raw_duration - self.__start_time + duration_seconds = min(self.__duration, duration_from_start) estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) if estimated_frames > 0: return estimated_frames # 3. 
Last resort: decode frames and count them (streaming) - frame_count = 0 - container.seek(0) - for packet in container.demux(video_stream): - for _ in packet.decode(): - frame_count += 1 - - if frame_count == 0: - raise ValueError(f"Could not determine frame count for file '{self.__file}'") + if self.__start_time < 0: + start_time = max(self._get_raw_duration() + self.__start_time, 0) + else: + start_time = self.__start_time + frame_count = 1 + start_pts = int(start_time / video_stream.time_base) + end_pts = int((start_time + self.__duration) / video_stream.time_base) + container.seek(start_pts, stream=video_stream) + frame_iterator = ( + container.decode(video_stream) + if video_stream.codec.capabilities & 0x100 + else container.demux(video_stream) + ) + for frame in frame_iterator: + if frame.pts >= start_pts: + break + else: + raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}") + for frame in frame_iterator: + if frame.pts >= end_pts: + break + frame_count += 1 return frame_count def get_frame_rate(self) -> Fraction: @@ -199,9 +233,21 @@ class VideoFromFile(VideoInput): return container.format.name def get_components_internal(self, container: InputContainer) -> VideoComponents: + video_stream = self._get_first_video_stream(container) + if self.__start_time < 0: + start_time = max(self._get_raw_duration() + self.__start_time, 0) + else: + start_time = self.__start_time # Get video frames frames = [] - for frame in container.decode(video=0): + start_pts = int(start_time / video_stream.time_base) + end_pts = int((start_time + self.__duration) / video_stream.time_base) + container.seek(start_pts, stream=video_stream) + for frame in container.decode(video_stream): + if frame.pts < start_pts: + continue + if self.__duration and frame.pts >= end_pts: + break img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) frames.append(img) @@ -209,31 +255,44 @@ class VideoFromFile(VideoInput): images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) # Get frame rate - video_stream = next(s for s in container.streams if s.type == 'video') - frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1) + frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) # Get audio if available audio = None - try: - container.seek(0) # Reset the container to the beginning - for stream in container.streams: - if stream.type != 'audio': - continue - assert isinstance(stream, av.AudioStream) - audio_frames = [] - for packet in container.demux(stream): - for frame in packet.decode(): - assert isinstance(frame, av.AudioFrame) - audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) - if len(audio_frames) > 0: - audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) - audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) - audio = AudioInput({ - "waveform": audio_tensor, - "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1, - }) - except StopIteration: - pass # No audio stream + container.seek(start_pts, stream=video_stream) + # Use last stream for consistency + if len(container.streams.audio): + audio_stream = container.streams.audio[-1] + audio_frames = [] + resample = av.audio.resampler.AudioResampler(format='fltp').resample + frames = itertools.chain.from_iterable( + 
map(resample, container.decode(audio_stream)) + ) + + has_first_frame = False + for frame in frames: + offset_seconds = start_time - frame.pts * audio_stream.time_base + to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) + if to_skip < frame.samples: + has_first_frame = True + break + if has_first_frame: + audio_frames.append(frame.to_ndarray()[..., to_skip:]) + + for frame in frames: + if self.__duration and frame.time > start_time + self.__duration: + break + audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) + if len(audio_frames) > 0: + audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples) + if self.__duration: + audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)] + + audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples) + audio = AudioInput({ + "waveform": audio_tensor, + "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1, + }) metadata = container.metadata return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) @@ -250,7 +309,7 @@ class VideoFromFile(VideoInput): path: str | io.BytesIO, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ): if isinstance(self.__file, io.BytesIO): self.__file.seek(0) # Reset the BytesIO object to the beginning @@ -262,15 +321,14 @@ class VideoFromFile(VideoInput): reuse_streams = False if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None: reuse_streams = False + if self.__start_time or self.__duration: + reuse_streams = False if not reuse_streams: components = self.get_components_internal(container) video = VideoFromComponents(components) return video.save_to( - path, - format=format, - codec=codec, - metadata=metadata + path, format=format, codec=codec, metadata=metadata ) streams = container.streams @@ -304,10 +362,21 @@ class VideoFromFile(VideoInput): output_container.mux(packet) def _get_first_video_stream(self, container: InputContainer): - video_stream = next((s for s in container.streams if s.type == "video"), None) - if video_stream is None: - raise ValueError(f"No video stream found in file '{self.__file}'") - return video_stream + if len(container.streams.video): + return container.streams.video[0] + raise ValueError(f"No video stream found in file '{self.__file}'") + + def as_trimmed( + self, start_time: float = 0, duration: float = 0, strict_duration: bool = True + ) -> VideoInput | None: + trimmed = VideoFromFile( + self.get_stream_source(), + start_time=start_time + self.__start_time, + duration=duration, + ) + if trimmed.get_duration() < duration and strict_duration: + return None + return trimmed class VideoFromComponents(VideoInput): @@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput): return VideoComponents( images=self.__components.images, audio=self.__components.audio, - frame_rate=self.__components.frame_rate + frame_rate=self.__components.frame_rate, ) def save_to( @@ -330,8 +399,9 @@ class VideoFromComponents(VideoInput): path: str, format: VideoContainer = VideoContainer.AUTO, codec: VideoCodec = VideoCodec.AUTO, - metadata: Optional[dict] = None + metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and 
codec != VideoCodec.H264:
@@ -339,6 +409,10 @@ class VideoFromComponents(VideoInput):
         extra_kwargs = {}
         if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
             extra_kwargs["format"] = format.value
+        elif isinstance(path, io.BytesIO):
+            # BytesIO has no file extension, so av.open can't infer the format.
+            # Default to mp4 since that's the only supported format anyway.
+            extra_kwargs["format"] = "mp4"
         with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
             # Add metadata before writing any streams
             if metadata is not None:
@@ -357,7 +431,10 @@ class VideoFromComponents(VideoInput):
         audio_stream: Optional[av.AudioStream] = None
         if self.__components.audio:
             audio_sample_rate = int(self.__components.audio['sample_rate'])
-            audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+            waveform = self.__components.audio['waveform']
+            waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+            layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+            audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)

         # Encode video
         for i, frame in enumerate(self.__components.images):
@@ -372,12 +449,21 @@ class VideoFromComponents(VideoInput):
                 output.mux(packet)

         if audio_stream and self.__components.audio:
-            waveform = self.__components.audio['waveform']
-            waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
-            frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+            frame = av.AudioFrame.from_ndarray(waveform.float().cpu().contiguous().numpy(), format='fltp', layout=layout)
             frame.sample_rate = audio_sample_rate
             frame.pts = 0
             output.mux(audio_stream.encode(frame))

             # Flush encoder
             output.mux(audio_stream.encode(None))
+
+    def as_trimmed(
+        self,
+        start_time: float | None = None,
+        duration: float | None = None,
+        strict_duration: bool = True,
+    ) -> VideoInput | None:
+        start_time = start_time or 0.0
+        duration = duration or 0.0
+        if strict_duration and self.get_duration() < start_time + duration:
+            return None
+        # TODO: Consider tracking duration and trimming at time of save?
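+        # A file-backed trim: the components are serialized via get_stream_source()
+        # and the trim window is applied lazily by the returned VideoFromFile.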
+ return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration) diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 4b14e5ded..1cbc8ed26 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -23,14 +23,12 @@ if TYPE_CHECKING: from comfy.samplers import CFGGuider, Sampler from comfy.sd import CLIP, VAE from comfy.sd import StyleModel as StyleModel_ - from comfy_api.input import VideoInput + from comfy_api.input import VideoInput, CurveInput as CurveInput_ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL +from ._util import MESH, VOXEL, SVG as _SVG, File3D -# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference class FolderType(str, Enum): input = "input" @@ -75,17 +73,14 @@ class RemoteOptions: class NumberDisplay(str, Enum): number = "number" slider = "slider" + gradient_slider = "gradientslider" -class _StringIOType(str): - def __ne__(self, value: object) -> bool: - if self == "*" or value == "*": - return False - if not isinstance(value, str): - return True - a = frozenset(self.split(",")) - b = frozenset(value.split(",")) - return not (b.issubset(a) or a.issubset(b)) +class ControlAfterGenerate(str, Enum): + fixed = "fixed" + increment = "increment" + decrement = "decrement" + randomize = "randomize" class _ComfyType(ABC): Type = Any @@ -126,8 +121,7 @@ def comfytype(io_type: str, **kwargs): new_cls.__module__ = cls.__module__ new_cls.__doc__ = cls.__doc__ # assign ComfyType attributes, if needed - # NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) - new_cls.io_type = _StringIOType(io_type) + new_cls.io_type = io_type if hasattr(new_cls, "Input") and new_cls.Input is not None: new_cls.Input.Parent = new_cls if hasattr(new_cls, "Output") and new_cls.Output is not None: @@ -166,7 +160,7 @@ class Input(_IO_V3): ''' Base class for a V3 Input. 
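+
+    Example (illustrative; ``Int.Input`` is one concrete subclass)::
+
+        Int.Input("steps", default=20, min=1, max=100, advanced=True)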
''' - def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): super().__init__() self.id = id self.display_name = display_name @@ -174,6 +168,8 @@ class Input(_IO_V3): self.tooltip = tooltip self.lazy = lazy self.extra_dict = extra_dict if extra_dict is not None else {} + self.rawLink = raw_link + self.advanced = advanced def as_dict(self): return prune_dict({ @@ -181,10 +177,12 @@ class Input(_IO_V3): "optional": self.optional, "tooltip": self.tooltip, "lazy": self.lazy, + "rawLink": self.rawLink, + "advanced": self.advanced, }) | prune_dict(self.extra_dict) def get_io_type(self): - return _StringIOType(self.io_type) + return self.io_type def get_all(self) -> list[Input]: return [self] @@ -195,8 +193,8 @@ class WidgetInput(Input): ''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: Any=None, - socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + socketless: bool=None, widget_type: str=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.default = default self.socketless = socketless self.widget_type = widget_type @@ -218,13 +216,14 @@ class Output(_IO_V3): def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): self.id = id - self.display_name = display_name + self.display_name = display_name if display_name else id self.tooltip = tooltip self.is_output_list = is_output_list def as_dict(self): + display_name = self.display_name if self.display_name else self.id return prune_dict({ - "display_name": self.display_name, + "display_name": display_name, "tooltip": self.tooltip, "is_output_list": self.is_output_list, }) @@ -252,8 +251,8 @@ class Boolean(ComfyTypeIO): '''Boolean input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: bool=None, label_on: str=None, label_off: str=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.label_on = label_on self.label_off = label_off self.default: bool @@ -271,9 +270,9 @@ class Int(ComfyTypeIO): class Input(WidgetInput): '''Integer input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None, + display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: 
bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step @@ -298,13 +297,15 @@ class Float(ComfyTypeIO): '''Float input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, - display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + display_mode: NumberDisplay=None, gradient_stops: list[dict]=None, + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.min = min self.max = max self.step = step self.round = round self.display_mode = display_mode + self.gradient_stops = gradient_stops self.default: float def as_dict(self): @@ -314,6 +315,7 @@ class Float(ComfyTypeIO): "step": self.step, "round": self.round, "display": self.display_mode, + "gradient_stops": self.gradient_stops, }) @comfytype(io_type="STRING") @@ -324,8 +326,8 @@ class String(ComfyTypeIO): '''String input.''' def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None, - socketless: bool=None, force_input: bool=None): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input) + socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) self.multiline = multiline self.placeholder = placeholder self.dynamic_prompts = dynamic_prompts @@ -353,17 +355,20 @@ class Combo(ComfyTypeIO): tooltip: str=None, lazy: bool=None, default: str | int | Enum = None, - control_after_generate: bool=None, + control_after_generate: bool | ControlAfterGenerate=None, upload: UploadType=None, image_folder: FolderType=None, remote: RemoteOptions=None, socketless: bool=None, + extra_dict=None, + raw_link: bool=None, + advanced: bool=None, ): if isinstance(options, type) and issubclass(options, Enum): options = [v.value for v in options] if isinstance(default, Enum): default = default.value - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) self.multiselect = False self.options = options self.control_after_generate = control_after_generate @@ -387,10 +392,6 @@ class Combo(ComfyTypeIO): super().__init__(id, display_name, tooltip, is_output_list) self.options = options if options is not None else [] - @property - def io_type(self): - return self.options - @comfytype(io_type="COMBO") class MultiCombo(ComfyTypeI): '''Multiselect Combo input (dropdown for selecting potentially more than one value).''' @@ -398,9 +399,9 @@ class MultiCombo(ComfyTypeI): Type = list[str] class Input(Combo.Input): def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, - default: list[str]=None, 
placeholder: str=None, chip: bool=None, control_after_generate: bool=None, - socketless: bool=None): - super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless) + default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None, + socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) self.multiselect = True self.placeholder = placeholder self.chip = chip @@ -433,9 +434,9 @@ class Webcam(ComfyTypeIO): Type = str def __init__( self, id: str, display_name: str=None, optional=False, - tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None + tooltip: str=None, lazy: bool=None, default: str=None, socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None ): - super().__init__(id, display_name, optional, tooltip, lazy, default, socketless) + super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, None, extra_dict, raw_link, advanced) @comfytype(io_type="MASK") @@ -656,7 +657,7 @@ class Video(ComfyTypeIO): @comfytype(io_type="SVG") class SVG(ComfyTypeIO): - Type = Any # TODO: SVG class is defined in comfy_extras/nodes_images.py, causing circular reference; should be moved to somewhere else before referenced directly in v3 + Type = _SVG @comfytype(io_type="LORA_MODEL") class LoraModel(ComfyTypeIO): @@ -676,6 +677,49 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH + +@comfytype(io_type="FILE_3D") +class File3DAny(ComfyTypeIO): + """General 3D file type - accepts any supported 3D format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLB") +class File3DGLB(ComfyTypeIO): + """GLB format 3D file - binary glTF, best for web and cross-platform.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_GLTF") +class File3DGLTF(ComfyTypeIO): + """GLTF format 3D file - JSON-based glTF with external resources.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_FBX") +class File3DFBX(ComfyTypeIO): + """FBX format 3D file - best for game engines and animation.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_OBJ") +class File3DOBJ(ComfyTypeIO): + """OBJ format 3D file - simple geometry format.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_STL") +class File3DSTL(ComfyTypeIO): + """STL format 3D file - best for 3D printing.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_USDZ") +class File3DUSDZ(ComfyTypeIO): + """USDZ format 3D file - Apple AR format.""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -763,7 +807,7 @@ class AnyType(ComfyTypeIO): Type = Any @comfytype(io_type="MODEL_PATCH") -class MODEL_PATCH(ComfyTypeIO): +class ModelPatch(ComfyTypeIO): Type = Any @comfytype(io_type="AUDIO_ENCODER") @@ -788,7 +832,7 @@ class MultiType: ''' Input that permits more than one input type; if `id` is an instance of `ComfyType.Input`, then that input will be used to create a widget (if applicable) with overridden values. 
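+
+        Example (illustrative)::
+
+            MultiType.Input("value", types=[Int, Float])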
''' - def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + def __init__(self, id: str | Input, types: list[type[_ComfyType] | _ComfyType], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): # if id is an Input, then use that Input with overridden values self.input_override = None if isinstance(id, Input): @@ -801,7 +845,7 @@ class MultiType: # if is a widget input, make sure widget_type is set appropriately if isinstance(self.input_override, WidgetInput): self.input_override.widget_type = self.input_override.get_io_type() - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self._io_types = types @property @@ -855,8 +899,8 @@ class MatchType(ComfyTypeIO): class Input(Input): def __init__(self, id: str, template: MatchType.Template, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict, raw_link, advanced) self.template = template def as_dict(self): @@ -867,6 +911,8 @@ class MatchType(ComfyTypeIO): class Output(Output): def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, is_output_list=False): + if not id and not display_name: + display_name = "MATCHTYPE" super().__init__(id, display_name, tooltip, is_output_list) self.template = template @@ -879,24 +925,30 @@ class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - def get_dynamic(self) -> list[Input]: - return [] - - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - pass + pass class DynamicOutput(Output, ABC): ''' Abstract class for dynamic output registration. 
''' - def __init__(self, id: str=None, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) + pass - def get_dynamic(self) -> list[Output]: - return [] +def handle_prefix(prefix_list: list[str] | None, id: str | None = None) -> list[str]: + if prefix_list is None: + prefix_list = [] + if id is not None: + prefix_list = prefix_list + [id] + return prefix_list + +def finalize_prefix(prefix_list: list[str] | None, id: str | None = None) -> str: + assert not (prefix_list is None and id is None) + if prefix_list is None: + return id + elif id is not None: + prefix_list = prefix_list + [id] + return ".".join(prefix_list) @comfytype(io_type="COMFY_AUTOGROW_V3") class Autogrow(ComfyTypeI): @@ -933,14 +985,6 @@ class Autogrow(ComfyTypeI): def validate(self): self.input.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - real_inputs = [] - for name, input in self.cached_inputs.items(): - if name in live_inputs: - real_inputs.append(input) - add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, real_inputs, curr_prefix) - class TemplatePrefix(_AutogrowTemplate): def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): super().__init__(input) @@ -985,22 +1029,63 @@ class Autogrow(ComfyTypeI): "template": self.template.as_dict(), }) - def get_dynamic(self) -> list[Input]: - return self.template.get_all() - def get_all(self) -> list[Input]: return [self] + self.template.get_all() def validate(self): self.template.validate() - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - curr_prefix = f"{curr_prefix}{self.id}." - # need to remove self from expected inputs dictionary; replaced by template inputs in frontend - for inner_dict in d.values(): - if self.id in inner_dict: - del inner_dict[self.id] - self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + # NOTE: purposely do not include self in out_dict; instead use only the template inputs + # need to figure out names based on template type + is_names = ("names" in value[1]["template"]) + is_prefix = ("prefix" in value[1]["template"]) + input = value[1]["template"]["input"] + if is_names: + min = value[1]["template"]["min"] + names = value[1]["template"]["names"] + max = len(names) + elif is_prefix: + prefix = value[1]["template"]["prefix"] + min = value[1]["template"]["min"] + max = value[1]["template"]["max"] + names = [f"{prefix}{i}" for i in range(max)] + # need to create a new input based on the contents of input + template_input = None + template_required = True + for _input_type, dict_input in input.items(): + # for now, get just the first value from dict_input; if not required, min can be ignored + if len(dict_input) == 0: + continue + template_input = list(dict_input.values())[0] + template_required = _input_type == "required" + break + if template_input is None: + raise Exception("template_input could not be determined from required or optional; this should never happen.") + new_dict = {} + new_dict_added_to = False + # first, add possible inputs into out_dict + for i, name in enumerate(names): + expected_id = finalize_prefix(curr_prefix, name) + # required + if i < min and template_required: + 
out_dict["required"][expected_id] = template_input + type_dict = new_dict.setdefault("required", {}) + # optional + else: + out_dict["optional"][expected_id] = template_input + type_dict = new_dict.setdefault("optional", {}) + if expected_id in live_inputs: + # NOTE: prefix gets added in parse_class_inputs + type_dict[name] = template_input + new_dict_added_to = True + # account for the edge case that all inputs are optional and no values are received + if not new_dict_added_to: + finalized_prefix = finalize_prefix(curr_prefix) + out_dict["dynamic_paths"][finalized_prefix] = finalized_prefix + out_dict["dynamic_paths_default_value"][finalized_prefix] = DynamicPathsDefaultValue.EMPTY_DICT + parse_class_inputs(out_dict, live_inputs, new_dict, curr_prefix) @comfytype(io_type="COMFY_DYNAMICCOMBO_V3") class DynamicCombo(ComfyTypeI): @@ -1023,23 +1108,6 @@ class DynamicCombo(ComfyTypeI): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) self.options = options - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - # check if dynamic input's id is in live_inputs - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." - key = live_inputs[self.id] - selected_option = None - for option in self.options: - if option.key == key: - selected_option = option - break - if selected_option is not None: - add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) - - def get_dynamic(self) -> list[Input]: - return [input for option in self.options for input in option.inputs] - def get_all(self) -> list[Input]: return [self] + [input for option in self.options for input in option.inputs] @@ -1054,6 +1122,24 @@ class DynamicCombo(ComfyTypeI): for input in option.inputs: input.validate() + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + key = live_inputs[finalized_id] + selected_option = None + # get options from dict + options: list[dict[str, str | dict[str, Any]]] = value[1]["options"] + for option in options: + if option["key"] == key: + selected_option = option + break + if selected_option is not None: + parse_class_inputs(out_dict, live_inputs, selected_option["inputs"], curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + @comfytype(io_type="COMFY_DYNAMICSLOT_V3") class DynamicSlot(ComfyTypeI): Type = dict[str, Any] @@ -1076,17 +1162,8 @@ class DynamicSlot(ComfyTypeI): self.force_input = True self.slot.force_input = True - def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): - if self.id in live_inputs: - curr_prefix = f"{curr_prefix}{self.id}." 
- add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) - add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) - - def get_dynamic(self) -> list[Input]: - return [self.slot] + self.inputs - def get_all(self) -> list[Input]: - return [self] + [self.slot] + self.inputs + return [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ @@ -1100,17 +1177,122 @@ class DynamicSlot(ComfyTypeI): for input in self.inputs: input.validate() -def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): - dynamic = d.setdefault("dynamic_paths", {}) - if self is not None: - dynamic[self.id] = f"{curr_prefix}{self.id}" - for i in inputs: - if not isinstance(i, DynamicInput): - dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + @staticmethod + def _expand_schema_for_dynamic(out_dict: dict[str, Any], live_inputs: dict[str, Any], value: tuple[str, dict[str, Any]], input_type: str, curr_prefix: list[str] | None): + finalized_id = finalize_prefix(curr_prefix) + if finalized_id in live_inputs: + inputs = value[1]["inputs"] + parse_class_inputs(out_dict, live_inputs, inputs, curr_prefix) + # add self to inputs + out_dict[input_type][finalized_id] = value + out_dict["dynamic_paths"][finalized_id] = finalize_prefix(curr_prefix, curr_prefix[-1]) + +@comfytype(io_type="IMAGECOMPARE") +class ImageCompare(ComfyTypeI): + Type = dict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, None, socketless, None, None, None, None, advanced) + + def as_dict(self): + return super().as_dict() + + +@comfytype(io_type="COLOR") +class Color(ComfyTypeIO): + Type = str + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, advanced: bool=None, default: str="#ffffff"): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + self.default: str + + def as_dict(self): + return super().as_dict() + +@comfytype(io_type="BOUNDING_BOX") +class BoundingBox(ComfyTypeIO): + class BoundingBoxDict(TypedDict): + x: int + y: int + width: int + height: int + Type = BoundingBoxDict + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, component: str=None, force_input: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless) + self.component = component + self.force_input = force_input + if default is None: + self.default = {"x": 0, "y": 0, "width": 512, "height": 512} + + def as_dict(self): + d = super().as_dict() + if self.component: + d["component"] = self.component + if self.force_input is not None: + d["forceInput"] = self.force_input + return d + + +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + from comfy_api.input import CurvePoint + if TYPE_CHECKING: + Type = CurveInput_ + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + def as_dict(self): 
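+            """Serialize the widget; the default curve ships as control points plus an interpolation mode."""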
+ d = super().as_dict() + if self.default is not None: + d["default"] = {"points": [list(p) for p in self.default], "interpolation": "monotone_cubic"} + return d + + +@comfytype(io_type="HISTOGRAM") +class Histogram(ComfyTypeIO): + """A histogram represented as a list of bin counts.""" + Type = list[int] + + +DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} +def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): + DYNAMIC_INPUT_LOOKUP[io_type] = func + +def get_dynamic_input_func(io_type: str) -> Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]: + return DYNAMIC_INPUT_LOOKUP[io_type] + +def setup_dynamic_input_funcs(): + # Autogrow.Input + register_dynamic_input_func(Autogrow.io_type, Autogrow._expand_schema_for_dynamic) + # DynamicCombo.Input + register_dynamic_input_func(DynamicCombo.io_type, DynamicCombo._expand_schema_for_dynamic) + # DynamicSlot.Input + register_dynamic_input_func(DynamicSlot.io_type, DynamicSlot._expand_schema_for_dynamic) + +if len(DYNAMIC_INPUT_LOOKUP) == 0: + setup_dynamic_input_funcs() class V3Data(TypedDict): hidden_inputs: dict[str, Any] + 'Dictionary where the keys are the hidden input ids and the values are the values of the hidden inputs.' dynamic_paths: dict[str, Any] + 'Dictionary where the keys are the input ids and the values dictate how to turn the inputs into a nested dictionary.' + dynamic_paths_default_value: dict[str, Any] + 'Dictionary where the keys are the input ids and the values are a string from DynamicPathsDefaultValue for the inputs if value is None.' + create_dynamic_tuple: bool + 'When True, the value of the dynamic input will be in the format (value, path_key).' class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -1146,6 +1328,10 @@ class HiddenHolder: api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None), ) + @classmethod + def from_v3_data(cls, v3_data: V3Data | None) -> HiddenHolder: + return cls.from_dict(v3_data["hidden_inputs"] if v3_data else None) + class Hidden(str, Enum): ''' Enumerator for requesting hidden variables in nodes. 
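+
+    Example (illustrative; keys and values are placeholders)::
+
+        holder = HiddenHolder.from_v3_data({"hidden_inputs": {Hidden.unique_id: "7"}})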
@@ -1168,6 +1354,7 @@ class Hidden(str, Enum): class NodeInfoV1: input: dict=None input_order: dict[str, list[str]]=None + is_input_list: bool=None output: list[str]=None output_is_list: list[bool]=None output_name: list[str]=None @@ -1181,21 +1368,75 @@ class NodeInfoV1: output_node: bool=None deprecated: bool=None experimental: bool=None + dev_only: bool=None api_node: bool=None + price_badge: dict | None = None + search_aliases: list[str]=None + essentials_category: str=None + @dataclass -class NodeInfoV3: - input: dict=None - output: dict=None - hidden: list[str]=None - name: str=None - display_name: str=None - description: str=None - category: str=None - output_node: bool=None - deprecated: bool=None - experimental: bool=None - api_node: bool=None +class PriceBadgeDepends: + widgets: list[str] = field(default_factory=list) + inputs: list[str] = field(default_factory=list) + input_groups: list[str] = field(default_factory=list) + + def validate(self) -> None: + if not isinstance(self.widgets, list) or any(not isinstance(x, str) for x in self.widgets): + raise ValueError("PriceBadgeDepends.widgets must be a list[str].") + if not isinstance(self.inputs, list) or any(not isinstance(x, str) for x in self.inputs): + raise ValueError("PriceBadgeDepends.inputs must be a list[str].") + if not isinstance(self.input_groups, list) or any(not isinstance(x, str) for x in self.input_groups): + raise ValueError("PriceBadgeDepends.input_groups must be a list[str].") + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + # Build lookup: widget_id -> io_type + input_types: dict[str, str] = {} + for inp in schema_inputs: + all_inputs = inp.get_all() + input_types[inp.id] = inp.get_io_type() # First input is always the parent itself + for nested_inp in all_inputs[1:]: + # For DynamicCombo/DynamicSlot, nested inputs are prefixed with parent ID + # to match frontend naming convention (e.g., "should_texture.enable_pbr") + prefixed_id = f"{inp.id}.{nested_inp.id}" + input_types[prefixed_id] = nested_inp.get_io_type() + + # Enrich widgets with type information, raising error for unknown widgets + widgets_data: list[dict[str, str]] = [] + for w in self.widgets: + if w not in input_types: + raise ValueError( + f"PriceBadge depends_on.widgets references unknown widget '{w}'. " + f"Available widgets: {list(input_types.keys())}" + ) + widgets_data.append({"name": w, "type": input_types[w]}) + + return { + "widgets": widgets_data, + "inputs": self.inputs, + "input_groups": self.input_groups, + } + + +@dataclass +class PriceBadge: + expr: str + depends_on: PriceBadgeDepends = field(default_factory=PriceBadgeDepends) + engine: str = field(default="jsonata") + + def validate(self) -> None: + if self.engine != "jsonata": + raise ValueError(f"Unsupported PriceBadge.engine '{self.engine}'. Only 'jsonata' is supported.") + if not isinstance(self.expr, str) or not self.expr.strip(): + raise ValueError("PriceBadge.expr must be a non-empty string.") + self.depends_on.validate() + + def as_dict(self, schema_inputs: list["Input"]) -> dict[str, Any]: + return { + "engine": self.engine, + "depends_on": self.depends_on.as_dict(schema_inputs), + "expr": self.expr, + } @dataclass @@ -1213,6 +1454,8 @@ class Schema: hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" + search_aliases: list[str] = field(default_factory=list) + """Alternative names for search. 
Useful for synonyms, abbreviations, or old names after renaming.""" is_input_list: bool = False """A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes. @@ -1239,73 +1482,78 @@ class Schema: """Flags a node as deprecated, indicating to users that they should find alternatives to this node.""" is_experimental: bool=False """Flags a node as experimental, informing users that it may change or not work as expected.""" + is_dev_only: bool=False + """Flags a node as dev-only, hiding it from search/menus unless dev mode is enabled.""" is_api_node: bool=False """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview.""" + price_badge: PriceBadge | None = None + """Optional client-evaluated pricing badge declaration for this node.""" not_idempotent: bool=False """Flags a node as not idempotent; when True, the node will run and not reuse the cached outputs when identical inputs are provided on a different node in the graph.""" enable_expand: bool=False """Flags a node as expandable, allowing NodeOutput to include 'expand' property.""" + accept_all_inputs: bool=False + """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" + essentials_category: str | None = None + """Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing').""" def validate(self): '''Validate the schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' nested_inputs: list[Input] = [] - if self.inputs is not None: - for input in self.inputs: + for input in self.inputs: + if not isinstance(input, DynamicInput): nested_inputs.extend(input.get_all()) - input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] - output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] + input_ids = [i.id for i in nested_inputs] + output_ids = [o.id for o in self.outputs] input_set = set(input_ids) output_set = set(output_ids) - issues = [] + issues: list[str] = [] # verify ids are unique per list if len(input_set) != len(input_ids): issues.append(f"Input ids must be unique, but {[item for item, count in Counter(input_ids).items() if count > 1]} are not.") if len(output_set) != len(output_ids): issues.append(f"Output ids must be unique, but {[item for item, count in Counter(output_ids).items() if count > 1]} are not.") - # verify ids are unique between lists - intersection = input_set & output_set - if len(intersection) > 0: - issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) # validate inputs and outputs - if self.inputs is not None: - for input in self.inputs: - input.validate() - if self.outputs is not None: - for output in self.outputs: - output.validate() + for input in self.inputs: + input.validate() + for output in self.outputs: + output.validate() + if self.price_badge is not None: + self.price_badge.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" + # ensure inputs, outputs, and hidden are lists + if self.inputs is None: + self.inputs = [] + if self.outputs is None: + self.outputs = [] + if self.hidden is None: + self.hidden = [] # if is an api_node, will need key-related hidden if self.is_api_node: - if self.hidden is None: - self.hidden = [] if Hidden.auth_token_comfy_org not in 
self.hidden:
                 self.hidden.append(Hidden.auth_token_comfy_org)
             if Hidden.api_key_comfy_org not in self.hidden:
                 self.hidden.append(Hidden.api_key_comfy_org)
         # if is an output_node, will need prompt and extra_pnginfo
         if self.is_output_node:
-            if self.hidden is None:
-                self.hidden = []
             if Hidden.prompt not in self.hidden:
                 self.hidden.append(Hidden.prompt)
             if Hidden.extra_pnginfo not in self.hidden:
                 self.hidden.append(Hidden.extra_pnginfo)
         # give outputs without ids default ids
-        if self.outputs is not None:
-            for i, output in enumerate(self.outputs):
-                if output.id is None:
-                    output.id = f"_{i}_{output.io_type}_"
+        for i, output in enumerate(self.outputs):
+            if output.id is None:
+                output.id = f"_{i}_{output.io_type}_"

-    def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
-        # NOTE: live_inputs will not be used anymore very soon and this will be done another way
+    def get_v1_info(self, cls) -> NodeInfoV1:
         # get V1 inputs
-        input = create_input_dict_v1(self.inputs, live_inputs)
+        input = create_input_dict_v1(self.inputs)
         if self.hidden:
             for hidden in self.hidden:
                 input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1336,6 +1584,7 @@ class Schema:
         info = NodeInfoV1(
             input=input,
             input_order={key: list(value.keys()) for (key, value) in input.items()},
+            is_input_list=self.is_input_list,
             output=output,
             output_is_list=output_is_list,
             output_name=output_name,
@@ -1348,81 +1597,83 @@ class Schema:
             output_node=self.is_output_node,
             deprecated=self.is_deprecated,
             experimental=self.is_experimental,
+            dev_only=self.is_dev_only,
             api_node=self.is_api_node,
-            python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes")
+            python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes"),
+            price_badge=self.price_badge.as_dict(self.inputs) if self.price_badge is not None else None,
+            search_aliases=self.search_aliases if self.search_aliases else None,
+            essentials_category=self.essentials_category,
         )
         return info

+def get_finalized_class_inputs(d: dict[str, Any], live_inputs: dict[str, Any], include_hidden=False) -> tuple[dict[str, Any], dict[str, Any] | None, V3Data]:
+    """Expand dynamic inputs against live values; returns (inputs dict, hidden dict or None, V3Data)."""
+    out_dict = {
+        "required": {},
+        "optional": {},
+        "dynamic_paths": {},
+        "dynamic_paths_default_value": {},
+    }
+    d = d.copy()
+    # ignore hidden for parsing
+    hidden = d.pop("hidden", None)
+    parse_class_inputs(out_dict, live_inputs, d)
+    if hidden is not None and include_hidden:
+        out_dict["hidden"] = hidden
+    v3_data = {}
+    dynamic_paths = out_dict.pop("dynamic_paths", None)
+    if dynamic_paths is not None and len(dynamic_paths) > 0:
+        v3_data["dynamic_paths"] = dynamic_paths
+    # this list is used for autogrow, in the case all inputs are optional and no values are passed
+    dynamic_paths_default_value = out_dict.pop("dynamic_paths_default_value", None)
+    if dynamic_paths_default_value is not None and len(dynamic_paths_default_value) > 0:
+        v3_data["dynamic_paths_default_value"] = dynamic_paths_default_value
+    return out_dict, hidden, v3_data

-    def get_v3_info(self, cls) -> NodeInfoV3:
-        input_dict = {}
-        output_dict = {}
-        hidden_list = []
-        # TODO: make sure dynamic types will be handled correctly
-        if self.inputs:
-            for input in self.inputs:
-                add_to_dict_v3(input, input_dict)
-        if self.outputs:
-            for output in self.outputs:
-                add_to_dict_v3(output, output_dict)
-        if self.hidden:
-            for hidden in self.hidden:
-                hidden_list.append(hidden.value)

+def parse_class_inputs(out_dict: dict[str, Any], live_inputs: dict[str, Any], curr_dict: dict[str, Any], curr_prefix: list[str] | None=None) -> None:
+    for input_type, inner_d in
curr_dict.items(): + for id, value in inner_d.items(): + io_type = value[0] + if io_type in DYNAMIC_INPUT_LOOKUP: + # dynamic inputs need to be handled with lookup functions + dynamic_input_func = get_dynamic_input_func(io_type) + new_prefix = handle_prefix(curr_prefix, id) + dynamic_input_func(out_dict, live_inputs, value, input_type, new_prefix) + else: + # non-dynamic inputs get directly transferred + finalized_id = finalize_prefix(curr_prefix, id) + out_dict[input_type][finalized_id] = value + if curr_prefix: + out_dict["dynamic_paths"][finalized_id] = finalized_id - info = NodeInfoV3( - input=input_dict, - output=output_dict, - hidden=hidden_list, - name=self.node_id, - display_name=self.display_name, - description=self.description, - category=self.category, - output_node=self.is_output_node, - deprecated=self.is_deprecated, - experimental=self.is_experimental, - api_node=self.is_api_node, - python_module=getattr(cls, "RELATIVE_PYTHON_MODULE", "nodes") - ) - return info - - -def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: +def create_input_dict_v1(inputs: list[Input]) -> dict: input = { "required": {} } - add_to_input_dict_v1(input, inputs, live_inputs) + for i in inputs: + add_to_dict_v1(i, input) return input -def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): - for i in inputs: - if isinstance(i, DynamicInput): - add_to_dict_v1(i, d) - if live_inputs is not None: - i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) - else: - add_to_dict_v1(i, d) - -def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): +def add_to_dict_v1(i: Input, d: dict): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - if dynamic_dict is None: - value = (i.get_io_type(), as_dict) - else: - value = (i.get_io_type(), as_dict, dynamic_dict) - d.setdefault(key, {})[i.id] = value + d.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) -def add_to_dict_v3(io: Input | Output, d: dict): - d[io.id] = (io.get_io_type(), io.as_dict()) +class DynamicPathsDefaultValue: + EMPTY_DICT = "empty_dict" def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): paths = v3_data.get("dynamic_paths", None) + default_value_dict = v3_data.get("dynamic_paths_default_value", {}) if paths is None: return values values = values.copy() + result = {} + create_tuple = v3_data.get("create_dynamic_tuple", False) + for key, path in paths.items(): parts = path.split(".") current = result @@ -1431,7 +1682,15 @@ def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): is_last = (i == len(parts) - 1) if is_last: - current[p] = values.pop(key, None) + value = values.pop(key, None) + if value is None: + # see if a default value was provided for this key + default_option = default_value_dict.get(key, None) + if default_option == DynamicPathsDefaultValue.EMPTY_DICT: + value = {} + if create_tuple: + value = (value, key) + current[p] = value else: current = current.setdefault(p, {}) @@ -1446,7 +1705,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): SCHEMA = None # filled in during execution - resources: Resources = None hidden: HiddenHolder = None @classmethod @@ -1493,7 +1751,6 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): return [name for name in kwargs if kwargs[name] is None] def __init__(self): - self.local_resources: ResourcesLocal = None self.__class__.VALIDATE_CLASS() @classmethod @@ -1561,15 
+1818,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"] if v3_data else None) + type_clone.hidden = HiddenHolder.from_v3_data(v3_data) return type_clone - - @final - @classmethod - def GET_NODE_INFO_V3(cls) -> dict[str, Any]: - schema = cls.GET_SCHEMA() - info = schema.get_v3_info(cls) - return asdict(info) ############################################# # V1 Backwards Compatibility code #-------------------------------------------- @@ -1612,6 +1862,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._DEPRECATED + _DEV_ONLY = None + @final + @classproperty + def DEV_ONLY(cls): # noqa + if cls._DEV_ONLY is None: + cls.GET_SCHEMA() + return cls._DEV_ONLY + _API_NODE = None @final @classproperty @@ -1676,21 +1934,20 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls.GET_SCHEMA() return cls._NOT_IDEMPOTENT + _ACCEPT_ALL_INPUTS = None + @final + @classproperty + def ACCEPT_ALL_INPUTS(cls): # noqa + if cls._ACCEPT_ALL_INPUTS is None: + cls.GET_SCHEMA() + return cls._ACCEPT_ALL_INPUTS + @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: + def INPUT_TYPES(cls) -> dict[str, dict]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls, live_inputs) - input = info.input - if not include_hidden: - input.pop("hidden", None) - if return_schema: - v3_data: V3Data = {} - dynamic = input.pop("dynamic_paths", None) - if dynamic is not None: - v3_data["dynamic_paths"] = dynamic - return input, schema, v3_data - return input + info = schema.get_v1_info(cls) + return info.input @final @classmethod @@ -1715,6 +1972,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._EXPERIMENTAL = schema.is_experimental if cls._DEPRECATED is None: cls._DEPRECATED = schema.is_deprecated + if cls._DEV_ONLY is None: + cls._DEV_ONLY = schema.is_dev_only if cls._API_NODE is None: cls._API_NODE = schema.is_api_node if cls._OUTPUT_NODE is None: @@ -1723,6 +1982,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): cls._INPUT_IS_LIST = schema.is_input_list if cls._NOT_IDEMPOTENT is None: cls._NOT_IDEMPOTENT = schema.not_idempotent + if cls._ACCEPT_ALL_INPUTS is None: + cls._ACCEPT_ALL_INPUTS = schema.accept_all_inputs if cls._RETURN_TYPES is None: output = [] @@ -1809,7 +2070,7 @@ class NodeOutput(_NodeOutputInternal): return self.args if len(self.args) > 0 else None @classmethod - def from_dict(cls, data: dict[str, Any]) -> "NodeOutput": + def from_dict(cls, data: dict[str, Any]) -> NodeOutput: args = () ui = None expand = None @@ -1836,11 +2097,74 @@ class _UIOutput(ABC): ... +class InputMapOldId(TypedDict): + """Map an old node input to a new node input by ID.""" + new_id: str + old_id: str + +class InputMapSetValue(TypedDict): + """Set a specific value for a new node input.""" + new_id: str + set_value: Any + +InputMap = InputMapOldId | InputMapSetValue +""" +Input mapping for node replacement. 
Type is inferred by dictionary keys: +- {"new_id": str, "old_id": str} - maps old input to new input +- {"new_id": str, "set_value": Any} - sets a specific value for new input +""" + +class OutputMap(TypedDict): + """Map outputs of node replacement via indexes.""" + new_idx: int + old_idx: int + +class NodeReplace: + """ + Defines a possible node replacement, mapping inputs and outputs of the old node to the new node. + + Also supports assigning specific values to the input widgets of the new node. + + Args: + new_node_id: The class name of the new replacement node. + old_node_id: The class name of the deprecated node. + old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot + connected. The workflow JSON stores widget values by their relative position index, + not by ID. This list maps those positional indexes to input IDs, enabling the + replacement system to correctly identify widget values during node migration. + input_mapping: List of input mappings from old node to new node. + output_mapping: List of output mappings from old node to new node. + """ + def __init__(self, + new_node_id: str, + old_node_id: str, + old_widget_ids: list[str] | None=None, + input_mapping: list[InputMap] | None=None, + output_mapping: list[OutputMap] | None=None, + ): + self.new_node_id = new_node_id + self.old_node_id = old_node_id + self.old_widget_ids = old_widget_ids + self.input_mapping = input_mapping + self.output_mapping = output_mapping + + def as_dict(self): + """Create serializable representation of the node replacement.""" + return { + "new_node_id": self.new_node_id, + "old_node_id": self.old_node_id, + "old_widget_ids": self.old_widget_ids, + "input_mapping": list(self.input_mapping) if self.input_mapping else None, + "output_mapping": list(self.output_mapping) if self.output_mapping else None, + } + + __all__ = [ "FolderType", "UploadType", "RemoteOptions", "NumberDisplay", + "ControlAfterGenerate", "comfytype", "Custom", @@ -1870,6 +2194,7 @@ __all__ = [ "ControlNet", "Vae", "Model", + "ModelPatch", "ClipVision", "ClipVisionOutput", "AudioEncoder", @@ -1885,6 +2210,13 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "File3DAny", + "File3DGLB", + "File3DGLTF", + "File3DFBX", + "File3DOBJ", + "File3DSTL", + "File3DUSDZ", "Hooks", "HookKeyframes", "TimestepsRange", @@ -1902,19 +2234,25 @@ __all__ = [ "AnyType", "MultiType", "Tracks", + "Color", # Dynamic Types "MatchType", - # "DynamicCombo", - # "Autogrow", + "DynamicCombo", + "Autogrow", # Other classes "HiddenHolder", "Hidden", "NodeInfoV1", - "NodeInfoV3", "Schema", "ComfyNode", "NodeOutput", "add_to_dict_v1", - "add_to_dict_v3", "V3Data", + "ImageCompare", + "PriceBadgeDepends", + "PriceBadge", + "BoundingBox", + "Curve", + "Histogram", + "NodeReplace", ] diff --git a/comfy_api/latest/_resources.py b/comfy_api/latest/_resources.py deleted file mode 100644 index a6bdda972..000000000 --- a/comfy_api/latest/_resources.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations -import comfy.utils -import folder_paths -import logging -from abc import ABC, abstractmethod -from typing import Any -import torch - -class ResourceKey(ABC): - Type = Any - def __init__(self): - ... 
- -class TorchDictFolderFilename(ResourceKey): - '''Key for requesting a torch file via file_name from a folder category.''' - Type = dict[str, torch.Tensor] - def __init__(self, folder_name: str, file_name: str): - self.folder_name = folder_name - self.file_name = file_name - - def __hash__(self): - return hash((self.folder_name, self.file_name)) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, TorchDictFolderFilename): - return False - return self.folder_name == other.folder_name and self.file_name == other.file_name - - def __str__(self): - return f"{self.folder_name} -> {self.file_name}" - -class Resources(ABC): - def __init__(self): - ... - - @abstractmethod - def get(self, key: ResourceKey, default: Any=...) -> Any: - pass - -class ResourcesLocal(Resources): - def __init__(self): - super().__init__() - self.local_resources: dict[ResourceKey, Any] = {} - - def get(self, key: ResourceKey, default: Any=...) -> Any: - cached = self.local_resources.get(key, None) - if cached is not None: - logging.info(f"Using cached resource '{key}'") - return cached - logging.info(f"Loading resource '{key}'") - to_return = None - if isinstance(key, TorchDictFolderFilename): - if default is ...: - to_return = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise(key.folder_name, key.file_name), safe_load=True) - else: - full_path = folder_paths.get_full_path(key.folder_name, key.file_name) - if full_path is not None: - to_return = comfy.utils.load_torch_file(full_path, safe_load=True) - - if to_return is not None: - self.local_resources[key] = to_return - return to_return - if default is not ...: - return default - raise Exception(f"Unsupported resource key type: {type(key)}") - - -class _RESOURCES: - ResourceKey = ResourceKey - TorchDictFolderFilename = TorchDictFolderFilename - Resources = Resources - ResourcesLocal = ResourcesLocal diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index fc5431dda..115baf392 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,6 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH +from .geometry_types import VOXEL, MESH, File3D +from .image_types import SVG __all__ = [ # Utility Types @@ -8,4 +9,6 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "File3D", + "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index 385122778..b586fceb3 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -1,3 +1,8 @@ +import shutil +from io import BytesIO +from pathlib import Path +from typing import IO + import torch @@ -10,3 +15,75 @@ class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): self.vertices = vertices self.faces = faces + + +class File3D: + """Class representing a 3D file from a file path or binary stream. + + Supports both disk-backed (file path) and memory-backed (BytesIO) storage. 
+    """
+
+    def __init__(self, source: str | IO[bytes], file_format: str = ""):
+        self._source = source
+        self._format = file_format or self._infer_format()
+
+    def _infer_format(self) -> str:
+        if isinstance(self._source, str):
+            return Path(self._source).suffix.lstrip(".").lower()
+        return ""
+
+    @property
+    def format(self) -> str:
+        return self._format
+
+    @format.setter
+    def format(self, value: str) -> None:
+        self._format = value.lstrip(".").lower() if value else ""
+
+    @property
+    def is_disk_backed(self) -> bool:
+        return isinstance(self._source, str)
+
+    def get_source(self) -> str | IO[bytes]:
+        if isinstance(self._source, str):
+            return self._source
+        if hasattr(self._source, "seek"):
+            self._source.seek(0)
+        return self._source
+
+    def get_data(self) -> BytesIO:
+        if isinstance(self._source, str):
+            with open(self._source, "rb") as f:
+                result = BytesIO(f.read())
+            return result
+        if hasattr(self._source, "seek"):
+            self._source.seek(0)
+        if isinstance(self._source, BytesIO):
+            return self._source
+        return BytesIO(self._source.read())
+
+    def save_to(self, path: str) -> str:
+        dest = Path(path)
+        dest.parent.mkdir(parents=True, exist_ok=True)
+
+        if isinstance(self._source, str):
+            if Path(self._source).resolve() != dest.resolve():
+                shutil.copy2(self._source, dest)
+        else:
+            if hasattr(self._source, "seek"):
+                self._source.seek(0)
+            with open(dest, "wb") as f:
+                f.write(self._source.read())
+        return str(dest)
+
+    def get_bytes(self) -> bytes:
+        if isinstance(self._source, str):
+            return Path(self._source).read_bytes()
+        if hasattr(self._source, "seek"):
+            self._source.seek(0)
+        return self._source.read()
+
+    def __repr__(self) -> str:
+        if isinstance(self._source, str):
+            return f"File3D(source={self._source!r}, format={self._format!r})"
+        return f"File3D(source=<{type(self._source).__name__}>, format={self._format!r})"
diff --git a/comfy_api/latest/_util/image_types.py b/comfy_api/latest/_util/image_types.py
new file mode 100644
index 000000000..f031ed426
--- /dev/null
+++ b/comfy_api/latest/_util/image_types.py
@@ -0,0 +1,18 @@
+from io import BytesIO
+
+
+class SVG:
+    """Stores SVG representations via a list of BytesIO objects."""
+
+    def __init__(self, data: list[BytesIO]):
+        self.data = data
+
+    def combine(self, other: 'SVG') -> 'SVG':
+        return SVG(self.data + other.data)
+
+    @staticmethod
+    def combine_all(svgs: list['SVG']) -> 'SVG':
+        all_svgs_list: list[BytesIO] = []
+        for svg_item in svgs:
+            all_svgs_list.extend(svg_item.data)
+        return SVG(all_svgs_list)
diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md
deleted file mode 100644
index f56d6c860..000000000
--- a/comfy_api_nodes/README.md
+++ /dev/null
@@ -1,65 +0,0 @@
-# ComfyUI API Nodes
-
-## Introduction
-
-Below are a collection of nodes that work by calling external APIs. More information available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview).
-
-## Development
-
-While developing, you should be testing against the Staging environment. To test against staging:
-
-**Install ComfyUI_frontend**
-
-Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
-
-> **Hint:** If you use --front-end-version argument for ComfyUI, it will use production authentication.
-
-```bash
-python run main.py --comfy-api-base https://stagingapi.comfy.org
-```
-
-To authenticate to staging, please login and then ask one of Comfy Org team to whitelist you for access to staging.
- -API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes. - -### Redocly Instructions - -**Tip** -When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet. - -Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging. - -```bash -# Download the OpenAPI file from staging server. -curl -o openapi.yaml https://stagingapi.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` - - -# Merging to Master - -Before merging to comfyanonymous/ComfyUI master, follow these steps: - -1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes. -1. Make sure the ComfyUI API is deployed to prod with your changes. -1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file. - -```bash -# Download the OpenAPI file from prod server. -curl -o openapi.yaml https://api.comfy.org/openapi - -# Filter out unneeded API definitions. -npm install -g @redocly/cli -redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components - -# Generate the pydantic datamodels for validation. -datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel - -``` diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py index ee2aa1ce6..46a583b5e 100644 --- a/comfy_api_nodes/apis/__init__.py +++ b/comfy_api_nodes/apis/__init__.py @@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum): face = 'face' -class KlingImageGenModelName(str, Enum): - kling_v1 = 'kling-v1' - kling_v1_5 = 'kling-v1-5' - kling_v2 = 'kling-v2' - - class KlingImageGenerationsRequest(BaseModel): aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9' callback_url: Optional[AnyUrl] = Field( @@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel): 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0 ) image_reference: Optional[KlingImageGenImageReferenceType] = None - model_name: Optional[KlingImageGenModelName] = 'kling-v1' + model_name: str = Field(...) 
     n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
     negative_prompt: Optional[str] = Field(
         None, description='Negative text prompt', max_length=200
diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl.py
similarity index 100%
rename from comfy_api_nodes/apis/bfl_api.py
rename to comfy_api_nodes/apis/bfl.py
diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py
new file mode 100644
index 000000000..8c496b56c
--- /dev/null
+++ b/comfy_api_nodes/apis/bria.py
@@ -0,0 +1,99 @@
+from typing import TypedDict
+
+from pydantic import BaseModel, Field
+
+
+class InputModerationSettings(TypedDict):
+    prompt_content_moderation: bool
+    visual_input_moderation: bool
+    visual_output_moderation: bool
+
+
+class BriaEditImageRequest(BaseModel):
+    instruction: str | None = Field(...)
+    structured_instruction: str | None = Field(
+        ...,
+        description="Use this instead of instruction for precise, programmatic control.",
+    )
+    images: list[str] = Field(
+        ...,
+        description="Required. Publicly available URL or Base64-encoded. Must contain exactly one item.",
+    )
+    mask: str | None = Field(
+        None,
+        description="Mask image (black and white). Black areas will be preserved, white areas will be edited. "
+        "If omitted, the edit applies to the entire image. "
+        "The input image and the input mask must be of the same size.",
+    )
+    negative_prompt: str | None = Field(None)
+    guidance_scale: float = Field(...)
+    model_version: str = Field(...)
+    steps_num: int = Field(...)
+    seed: int = Field(...)
+    ip_signal: bool = Field(
+        False,
+        description="If true, returns a warning for potential IP content in the instruction.",
+    )
+    prompt_content_moderation: bool = Field(
+        False, description="If true, returns 422 on instruction moderation failure."
+    )
+    visual_input_content_moderation: bool = Field(
+        False, description="If true, returns 422 on images or mask moderation failure."
+    )
+    visual_output_content_moderation: bool = Field(
+        False, description="If true, returns 422 on visual output moderation failure."
+    )
+
+
+class BriaRemoveBackgroundRequest(BaseModel):
+    image: str = Field(...)
+    sync: bool = Field(False)
+    visual_input_content_moderation: bool = Field(
+        False, description="If true, returns 422 on input image moderation failure."
+    )
+    visual_output_content_moderation: bool = Field(
+        False, description="If true, returns 422 on visual output moderation failure."
+    )
+    seed: int = Field(...)
+
+
+class BriaStatusResponse(BaseModel):
+    request_id: str = Field(...)
+    status_url: str = Field(...)
+    warning: str | None = Field(None)
+
+
+class BriaRemoveBackgroundResult(BaseModel):
+    image_url: str = Field(...)
+
+
+class BriaRemoveBackgroundResponse(BaseModel):
+    status: str = Field(...)
+    result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
+    structured_prompt: str = Field(...)
+    image_url: str = Field(...)
+
+
+class BriaImageEditResponse(BaseModel):
+    status: str = Field(...)
+    result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+    video: str = Field(...)
+    background_color: str = Field(default="transparent", description="Background color for the output video.")
+    output_container_and_codec: str = Field(...)
+    preserve_audio: bool = Field(True)
+    seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+    video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+    status: str = Field(...)
+ result: BriaRemoveVideoBackgroundResult | None = Field(None) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance.py similarity index 88% rename from comfy_api_nodes/apis/bytedance_api.py rename to comfy_api_nodes/apis/bytedance.py index 77cd76f9b..18455396d 100644 --- a/comfy_api_nodes/apis/bytedance_api.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -10,18 +10,7 @@ class Text2ImageTaskCreationRequest(BaseModel): size: str | None = Field(None) seed: int | None = Field(0, ge=0, le=2147483647) guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) - - -class Image2ImageTaskCreationRequest(BaseModel): - model: str = Field(...) - prompt: str = Field(...) - response_format: str | None = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: str | None = Field("adaptive") - seed: int | None = Field(..., ge=0, le=2147483647) - guidance_scale: float | None = Field(..., ge=1.0, le=10.0) - watermark: bool | None = Field(True) + watermark: bool | None = Field(False) class Seedream4Options(BaseModel): @@ -37,7 +26,8 @@ class Seedream4TaskCreationRequest(BaseModel): seed: int = Field(..., ge=0, le=2147483647) sequential_image_generation: str = Field("disabled") sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) + watermark: bool = Field(False) + output_format: str | None = None class ImageTaskCreationResponse(BaseModel): @@ -65,11 +55,13 @@ class TaskImageContent(BaseModel): class Text2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent] = Field(..., min_length=1) + generate_audio: bool | None = Field(...) class Image2VideoTaskCreationRequest(BaseModel): model: str = Field(...) content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + generate_audio: bool | None = Field(...) class TaskCreationResponse(BaseModel): @@ -115,6 +107,7 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [ ("2496x1664 (3:2)", 2496, 1664), ("1664x2496 (2:3)", 1664, 2496), ("3024x1296 (21:9)", 3024, 1296), + ("3072x3072 (1:1)", 3072, 3072), ("4096x4096 (1:1)", 4096, 4096), ("Custom", None, None), ] @@ -141,4 +134,9 @@ VIDEO_TASKS_EXECUTION_TIME = { "720p": 65, "1080p": 100, }, + "seedance-1-5-pro-251215": { + "480p": 80, + "720p": 100, + "1080p": 150, + }, } diff --git a/comfy_api_nodes/apis/elevenlabs.py b/comfy_api_nodes/apis/elevenlabs.py new file mode 100644 index 000000000..e58450fdf --- /dev/null +++ b/comfy_api_nodes/apis/elevenlabs.py @@ -0,0 +1,88 @@ +from pydantic import BaseModel, Field + + +class SpeechToTextRequest(BaseModel): + model_id: str = Field(...) + cloud_storage_url: str = Field(...) 
+ language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code") + tag_audio_events: bool | None = Field(None, description="Annotate sounds like (laughter) in transcript") + num_speakers: int | None = Field(None, description="Max speakers predicted") + timestamps_granularity: str = Field(default="word", description="Timing precision: none, word, or character") + diarize: bool | None = Field(None, description="Annotate which speaker is talking") + diarization_threshold: float | None = Field(None, description="Speaker separation sensitivity") + temperature: float | None = Field(None, description="Randomness control") + seed: int = Field(..., description="Seed for deterministic sampling") + + +class SpeechToTextWord(BaseModel): + text: str = Field(..., description="The word text") + type: str = Field(default="word", description="Type of text element (word, spacing, etc.)") + start: float | None = Field(None, description="Start time in seconds (when timestamps enabled)") + end: float | None = Field(None, description="End time in seconds (when timestamps enabled)") + speaker_id: str | None = Field(None, description="Speaker identifier when diarization is enabled") + logprob: float | None = Field(None, description="Log probability of the word") + + +class SpeechToTextResponse(BaseModel): + language_code: str = Field(..., description="Detected or specified language code") + language_probability: float | None = Field(None, description="Confidence of language detection") + text: str = Field(..., description="Full transcript text") + words: list[SpeechToTextWord] | None = Field(None, description="Word-level timing information") + + +class TextToSpeechVoiceSettings(BaseModel): + stability: float | None = Field(None, description="Voice stability") + similarity_boost: float | None = Field(None, description="Similarity boost") + style: float | None = Field(None, description="Style exaggeration") + use_speaker_boost: bool | None = Field(None, description="Boost similarity to original speaker") + speed: float | None = Field(None, description="Speech speed") + + +class TextToSpeechRequest(BaseModel): + text: str = Field(..., description="Text to convert to speech") + model_id: str = Field(..., description="Model ID for TTS") + language_code: str | None = Field(None, description="ISO-639-1 or ISO-639-3 language code") + voice_settings: TextToSpeechVoiceSettings | None = Field(None, description="Voice settings") + seed: int = Field(..., description="Seed for deterministic sampling") + apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off") + + +class TextToSoundEffectsRequest(BaseModel): + text: str = Field(..., description="Text prompt to convert into a sound effect") + duration_seconds: float = Field(..., description="Duration of generated sound in seconds") + prompt_influence: float = Field(..., description="How closely generation follows the prompt") + loop: bool | None = Field(None, description="Whether to create a smoothly looping sound effect") + + +class AddVoiceRequest(BaseModel): + name: str = Field(..., description="Name that identifies the voice") + remove_background_noise: bool = Field(..., description="Remove background noise from voice samples") + + +class AddVoiceResponse(BaseModel): + voice_id: str = Field(..., description="The newly created voice's unique identifier") + + +class SpeechToSpeechRequest(BaseModel): + model_id: str = Field(..., description="Model ID for speech-to-speech") + voice_settings: str = 
Field(..., description="JSON string of voice settings") + seed: int = Field(..., description="Seed for deterministic sampling") + remove_background_noise: bool = Field(..., description="Remove background noise from input audio") + + +class DialogueInput(BaseModel): + text: str = Field(..., description="Text content to convert to speech") + voice_id: str = Field(..., description="Voice identifier for this dialogue segment") + + +class DialogueSettings(BaseModel): + stability: float | None = Field(None, description="Voice stability (0-1)") + + +class TextToDialogueRequest(BaseModel): + inputs: list[DialogueInput] = Field(..., description="List of dialogue segments") + model_id: str = Field(..., description="Model ID for dialogue generation") + language_code: str | None = Field(None, description="ISO-639-1 language code") + settings: DialogueSettings | None = Field(None, description="Voice settings") + seed: int | None = Field(None, description="Seed for deterministic sampling") + apply_text_normalization: str | None = Field(None, description="Text normalization mode: auto, on, off") diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini.py similarity index 93% rename from comfy_api_nodes/apis/gemini_api.py rename to comfy_api_nodes/apis/gemini.py index f8edc38c9..22879fe18 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini.py @@ -67,6 +67,7 @@ class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) + thought: bool | None = Field(None) class GeminiTextPart(BaseModel): @@ -116,14 +117,26 @@ class GeminiGenerationConfig(BaseModel): topP: float | None = Field(None, ge=0.0, le=1.0) +class GeminiImageOutputOptions(BaseModel): + mimeType: str = Field("image/png") + compressionQuality: int | None = Field(None) + + class GeminiImageConfig(BaseModel): aspectRatio: str | None = Field(None) imageSize: str | None = Field(None) + imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) + + +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): @@ -133,6 +146,7 @@ class GeminiImageGenerateContentRequest(BaseModel): systemInstruction: GeminiSystemInstructionContent | None = Field(None) tools: list[GeminiTool] | None = Field(None) videoMetadata: GeminiVideoMetadata | None = Field(None) + uploadImagesToStorage: bool = Field(True) class GeminiGenerateContentRequest(BaseModel): diff --git a/comfy_api_nodes/apis/grok.py b/comfy_api_nodes/apis/grok.py new file mode 100644 index 000000000..fbedb53e0 --- /dev/null +++ b/comfy_api_nodes/apis/grok.py @@ -0,0 +1,83 @@ +from pydantic import BaseModel, Field + + +class ImageGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + aspect_ratio: str = Field(...) + n: int = Field(...) + seed: int = Field(...) + response_format: str = Field("url") + resolution: str = Field(...) + + +class InputUrlObject(BaseModel): + url: str = Field(...) + + +class ImageEditRequest(BaseModel): + model: str = Field(...) + images: list[InputUrlObject] = Field(...) + prompt: str = Field(...) + resolution: str = Field(...) + n: int = Field(...) 
+ seed: int = Field(...) + response_format: str = Field("url") + aspect_ratio: str | None = Field(...) + + +class VideoGenerationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + image: InputUrlObject | None = Field(None) + reference_images: list[InputUrlObject] | None = Field(None) + duration: int = Field(...) + aspect_ratio: str | None = Field(...) + resolution: str = Field(...) + seed: int = Field(...) + + +class VideoExtensionRequest(BaseModel): + prompt: str = Field(...) + video: InputUrlObject = Field(...) + duration: int = Field(default=6) + model: str | None = Field(default=None) + + +class VideoEditRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + video: InputUrlObject = Field(...) + seed: int = Field(...) + + +class ImageResponseObject(BaseModel): + url: str | None = Field(None) + b64_json: str | None = Field(None) + revised_prompt: str | None = Field(None) + + +class UsageObject(BaseModel): + cost_in_usd_ticks: int | None = Field(None) + + +class ImageGenerationResponse(BaseModel): + data: list[ImageResponseObject] = Field(...) + usage: UsageObject | None = Field(None) + + +class VideoGenerationResponse(BaseModel): + request_id: str = Field(...) + + +class VideoResponseObject(BaseModel): + url: str = Field(...) + upsampled_prompt: str | None = Field(None) + duration: int = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str | None = Field(None) + video: VideoResponseObject | None = Field(None) + model: str | None = Field(None) + usage: UsageObject | None = Field(None) diff --git a/comfy_api_nodes/apis/hitpaw.py b/comfy_api_nodes/apis/hitpaw.py new file mode 100644 index 000000000..b23c5d9eb --- /dev/null +++ b/comfy_api_nodes/apis/hitpaw.py @@ -0,0 +1,51 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + + +class InputVideoModel(TypedDict): + model: str + resolution: str + + +class ImageEnhanceTaskCreateRequest(BaseModel): + model_name: str = Field(...) + img_url: str = Field(...) + extension: str = Field(".png") + exif: bool = Field(False) + DPI: int | None = Field(None) + + +class VideoEnhanceTaskCreateRequest(BaseModel): + video_url: str = Field(...) + extension: str = Field(".mp4") + model_name: str | None = Field(...) + resolution: list[int] = Field(..., description="Target resolution [width, height]") + original_resolution: list[int] = Field(..., description="Original video resolution [width, height]") + + +class TaskCreateDataResponse(BaseModel): + job_id: str = Field(...) + consume_coins: int | None = Field(None) + + +class TaskStatusPollRequest(BaseModel): + job_id: str = Field(...) + + +class TaskCreateResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreateDataResponse | None = Field(None) + + +class TaskStatusDataResponse(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + res_url: str = Field("") + + +class TaskStatusResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskStatusDataResponse = Field(...) 
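[Editor's note] The HitPaw models above describe a create-then-poll contract: a create call returns a job_id inside TaskCreateResponse, and the job is then polled with TaskStatusPollRequest until TaskStatusResponse reaches a terminal status and res_url points at the output. A minimal sketch of that loop, assuming a hypothetical post_json HTTP helper, endpoint paths, success code, and status strings (none of which are defined in this diff):

import time

from comfy_api_nodes.apis.hitpaw import (
    ImageEnhanceTaskCreateRequest,
    TaskCreateResponse,
    TaskStatusPollRequest,
    TaskStatusResponse,
)


def enhance_image(post_json, img_url: str) -> str:
    # Create the enhancement job; post_json, the endpoint paths, and the
    # model name are illustrative stand-ins, not part of this diff.
    req = ImageEnhanceTaskCreateRequest(model_name="general", img_url=img_url)
    created = TaskCreateResponse(**post_json("/image/enhance", req.model_dump()))
    if created.code != 0 or created.data is None:  # assumed success code
        raise RuntimeError(created.message)

    poll = TaskStatusPollRequest(job_id=created.data.job_id)
    while True:
        status = TaskStatusResponse(**post_json("/task/status", poll.model_dump()))
        if status.data.status == "success":  # assumed status string
            return status.data.res_url
        if status.data.status == "failed":   # assumed status string
            raise RuntimeError(status.message)
        time.sleep(2)  # simple fixed back-off between polls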
diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py new file mode 100644 index 000000000..dad9bc2fa --- /dev/null +++ b/comfy_api_nodes/apis/hunyuan3d.py @@ -0,0 +1,96 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field, model_validator + + +class InputGenerateType(TypedDict): + generate_type: str + polygon_type: str + pbr: bool + + +class Hunyuan3DViewImage(BaseModel): + ViewType: str = Field(..., description="Valid values: back, left, right.") + ViewImageUrl: str = Field(...) + + +class To3DProTaskRequest(BaseModel): + Model: str = Field(...) + Prompt: str | None = Field(None) + ImageUrl: str | None = Field(None) + MultiViewImages: list[Hunyuan3DViewImage] | None = Field(None) + EnablePBR: bool | None = Field(...) + FaceCount: int | None = Field(...) + GenerateType: str | None = Field(...) + PolygonType: str | None = Field(...) + + +class RequestError(BaseModel): + Code: str = Field("") + Message: str = Field("") + + +class To3DProTaskCreateResponse(BaseModel): + JobId: str | None = Field(None) + Error: RequestError | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class ResultFile3D(BaseModel): + Type: str = Field(...) + Url: str = Field(...) + PreviewImageUrl: str = Field("") + + +class To3DProTaskResultResponse(BaseModel): + ErrorCode: str = Field("") + ErrorMessage: str = Field("") + ResultFile3Ds: list[ResultFile3D] = Field([]) + Status: str = Field(...) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "Response" in values and isinstance(values["Response"], dict): + return values["Response"] + return values + + +class To3DProTaskQueryRequest(BaseModel): + JobId: str = Field(...) + + +class TaskFile3DInput(BaseModel): + Type: str = Field(..., description="File type: GLB, OBJ, or FBX") + Url: str = Field(...) + + +class To3DUVTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) + + +class To3DPartTaskRequest(BaseModel): + File: TaskFile3DInput = Field(...) + + +class TextureEditImageInfo(BaseModel): + Url: str = Field(...) + + +class TextureEditTaskRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + Image: TextureEditImageInfo | None = Field(None) + Prompt: str | None = Field(None) + EnablePBR: bool | None = Field(None) + + +class SmartTopologyRequest(BaseModel): + File3D: TaskFile3DInput = Field(...) + PolygonType: str | None = Field(...) + FaceLevel: str | None = Field(...) 
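[Editor's note] The unwrap_data validators above make the Hunyuan3D response models tolerant of the Tencent-style envelope, where the payload is nested under a top-level "Response" key; an already-flat payload passes through unchanged. A small illustration using only models defined in this diff:

from comfy_api_nodes.apis.hunyuan3d import To3DProTaskCreateResponse

# Enveloped payload, as the API returns it: the mode="before" validator
# unwraps the "Response" dict before field validation runs.
wrapped = To3DProTaskCreateResponse.model_validate({"Response": {"JobId": "job-123"}})

# A flat payload with the same fields validates identically.
flat = To3DProTaskCreateResponse.model_validate({"JobId": "job-123"})

assert wrapped.JobId == flat.JobId == "job-123"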
diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py new file mode 100644 index 000000000..737e18e3b --- /dev/null +++ b/comfy_api_nodes/apis/ideogram.py @@ -0,0 +1,292 @@ +from enum import Enum +from typing import Optional, List, Dict, Any, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel, StrictBytes + + +class IdeogramColorPalette1(BaseModel): + name: str = Field(..., description='Name of the preset color palette') + + +class Member(BaseModel): + color: Optional[str] = Field( + None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$' + ) + weight: Optional[float] = Field( + None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0 + ) + + +class IdeogramColorPalette2(BaseModel): + members: List[Member] = Field( + ..., description='Array of color definitions with optional weights' + ) + + +class IdeogramColorPalette( + RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]] +): + root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field( + ..., + description='A color palette specification that can either use a preset name or explicit color definitions with weights', + ) + + +class ImageRequest(BaseModel): + aspect_ratio: Optional[str] = Field( + None, + description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.", + ) + color_palette: Optional[Dict[str, Any]] = Field( + None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.' + ) + magic_prompt_option: Optional[str] = Field( + None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')." + ) + model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')") + negative_prompt: Optional[str] = Field( + None, + description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.', + ) + num_images: Optional[int] = Field( + 1, + description='Optional. Number of images to generate (1-8). Defaults to 1.', + ge=1, + le=8, + ) + prompt: str = Field( + ..., description='Required. The prompt to use to generate the image.' + ) + resolution: Optional[str] = Field( + None, + description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.", + ) + seed: Optional[int] = Field( + None, + description='Optional. A number between 0 and 2147483647.', + ge=0, + le=2147483647, + ) + style_type: Optional[str] = Field( + None, + description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.", + ) + + +class IdeogramGenerateRequest(BaseModel): + image_request: ImageRequest = Field( + ..., description='The image generation request parameters.' + ) + + +class Datum(BaseModel): + is_image_safe: Optional[bool] = Field( + None, description='Indicates whether the image is considered safe.' + ) + prompt: Optional[str] = Field( + None, description='The prompt used to generate this image.' + ) + resolution: Optional[str] = Field( + None, description="The resolution of the generated image (e.g., '1024x1024')." + ) + seed: Optional[int] = Field( + None, description='The seed value used for this generation.' 
+ ) + style_type: Optional[str] = Field( + None, + description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').", + ) + url: Optional[str] = Field(None, description='URL to the generated image.') + + +class IdeogramGenerateResponse(BaseModel): + created: Optional[datetime] = Field( + None, description='Timestamp when the generation was created.' + ) + data: Optional[List[Datum]] = Field( + None, description='Array of generated image information.' + ) + + +class StyleCode(RootModel[str]): + root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$') + + +class Datum1(BaseModel): + is_image_safe: Optional[bool] = None + prompt: Optional[str] = None + resolution: Optional[str] = None + seed: Optional[int] = None + style_type: Optional[str] = None + url: Optional[str] = None + + +class IdeogramV3IdeogramResponse(BaseModel): + created: Optional[datetime] = None + data: Optional[List[Datum1]] = None + + +class RenderingSpeed1(str, Enum): + TURBO = 'TURBO' + DEFAULT = 'DEFAULT' + QUALITY = 'QUALITY' + + +class IdeogramV3ReframeRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + rendering_speed: Optional[RenderingSpeed1] = None + resolution: str + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class MagicPrompt(str, Enum): + AUTO = 'AUTO' + ON = 'ON' + OFF = 'OFF' + + +class StyleType(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + + +class IdeogramV3RemixRequest(BaseModel): + aspect_ratio: Optional[str] = None + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + image_weight: Optional[int] = Field(50, ge=1, le=100) + magic_prompt: Optional[MagicPrompt] = None + negative_prompt: Optional[str] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + resolution: Optional[str] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + style_type: Optional[StyleType] = None + + +class IdeogramV3ReplaceBackgroundRequest(BaseModel): + color_palette: Optional[Dict[str, Any]] = None + image: Optional[StrictBytes] = None + magic_prompt: Optional[MagicPrompt] = None + num_images: Optional[int] = Field(None, ge=1, le=8) + prompt: str + rendering_speed: Optional[RenderingSpeed1] = None + seed: Optional[int] = Field(None, ge=0, le=2147483647) + style_codes: Optional[List[str]] = None + style_reference_images: Optional[List[StrictBytes]] = None + + +class ColorPalette(BaseModel): + name: str = Field(..., description='Name of the color palette', examples=['PASTEL']) + + +class MagicPrompt2(str, Enum): + ON = 'ON' + OFF = 'OFF' + + +class StyleType1(str, Enum): + AUTO = 'AUTO' + GENERAL = 'GENERAL' + REALISTIC = 'REALISTIC' + DESIGN = 'DESIGN' + FICTION = 'FICTION' + + +class RenderingSpeed(str, Enum): + DEFAULT = 'DEFAULT' + TURBO = 'TURBO' + QUALITY = 'QUALITY' + + +class IdeogramV3EditRequest(BaseModel): + color_palette: Optional[IdeogramColorPalette] = None + image: Optional[StrictBytes] = Field( + None, + description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.', + ) + magic_prompt: Optional[str] = Field( + None, + description='Determine if MagicPrompt should 
be used in generating the request or not.', + ) + mask: Optional[StrictBytes] = Field( + None, + description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.', + ) + num_images: Optional[int] = Field( + None, description='The number of images to generate.' + ) + prompt: str = Field( + ..., description='The prompt used to describe the edited result.' + ) + rendering_speed: RenderingSpeed + seed: Optional[int] = Field( + None, description='Random seed. Set for reproducible generation.' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, + description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.', + ) + style_reference_images: Optional[List[StrictBytes]] = Field( + None, + description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.', + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' + ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) + + +class IdeogramV3Request(BaseModel): + aspect_ratio: Optional[str] = Field( + None, description='Aspect ratio in format WxH', examples=['1x3'] + ) + color_palette: Optional[ColorPalette] = None + magic_prompt: Optional[MagicPrompt2] = Field( + None, description='Whether to enable magic prompt enhancement' + ) + negative_prompt: Optional[str] = Field( + None, description='Text prompt specifying what to avoid in the generation' + ) + num_images: Optional[int] = Field( + None, description='Number of images to generate', ge=1 + ) + prompt: str = Field(..., description='The text prompt for image generation') + rendering_speed: RenderingSpeed + resolution: Optional[str] = Field( + None, description='Image resolution in format WxH', examples=['1280x800'] + ) + seed: Optional[int] = Field( + None, description='Seed value for reproducible generation' + ) + style_codes: Optional[List[StyleCode]] = Field( + None, description='Array of style codes in hexadecimal format' + ) + style_reference_images: Optional[List[str]] = Field( + None, description='Array of reference image URLs or identifiers' + ) + style_type: Optional[StyleType1] = Field( + None, description='The type of style to apply' + ) + character_reference_images: Optional[List[str]] = Field( + None, + description='Generations with character reference are subject to the character reference pricing. A set of images to use as character references (maximum total size 10MB across all character references), currently only supports 1 character reference image. The images should be in JPEG, PNG or WebP format.' 
+ ) + character_reference_images_mask: Optional[List[str]] = Field( + None, + description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' + ) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling.py similarity index 63% rename from comfy_api_nodes/apis/kling_api.py rename to comfy_api_nodes/apis/kling.py index 80a758466..fe0f97cb3 100644 --- a/comfy_api_nodes/apis/kling_api.py +++ b/comfy_api_nodes/apis/kling.py @@ -1,12 +1,22 @@ from pydantic import BaseModel, Field +class MultiPromptEntry(BaseModel): + index: int = Field(...) + prompt: str = Field(...) + duration: str = Field(...) + + class OmniProText2VideoRequest(BaseModel): model_name: str = Field(..., description="kling-video-o1") aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") duration: str = Field(..., description="'5' or '10'") prompt: str = Field(...) mode: str = Field("pro") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) + sound: str = Field(..., description="'on' or 'off'") class OmniParamImage(BaseModel): @@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel): duration: str = Field(..., description="'5' or '10'") prompt: str = Field(...) mode: str = Field("pro") + sound: str | None = Field(None, description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class OmniProReferences2VideoRequest(BaseModel): @@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel): duration: str | None = Field(..., description="From 3 to 10.") prompt: str = Field(...) mode: str = Field("pro") + sound: str | None = Field(None, description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class TaskStatusVideoResult(BaseModel): @@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel): class TaskStatusResults(BaseModel): videos: list[TaskStatusVideoResult] | None = Field(None) images: list[TaskStatusImageResult] | None = Field(None) + series_images: list[TaskStatusImageResult] | None = Field(None) class TaskStatusResponseData(BaseModel): @@ -77,28 +96,56 @@ class OmniImageParamImage(BaseModel): class OmniProImageRequest(BaseModel): - model_name: str = Field(..., description="kling-image-o1") - resolution: str = Field(..., description="'1k' or '2k'") + model_name: str = Field(...) + resolution: str = Field(...) aspect_ratio: str | None = Field(...) prompt: str = Field(...) mode: str = Field("pro") n: int | None = Field(1, le=9) image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) + result_type: str | None = Field(None, description="Set to 'series' for series generation") + series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series") class TextToVideoWithAudioRequest(BaseModel): - model_name: str = Field(..., description="kling-v2-6") + model_name: str = Field(...) aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") - duration: str = Field(..., description="'5' or '10'") - prompt: str = Field(...) + duration: str = Field(...) + prompt: str | None = Field(...) 
+ negative_prompt: str | None = Field(None) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) class ImageToVideoWithAudioRequest(BaseModel): - model_name: str = Field(..., description="kling-v2-6") + model_name: str = Field(...) image: str = Field(...) - duration: str = Field(..., description="'5' or '10'") - prompt: str = Field(...) + image_tail: str | None = Field(None) + duration: str = Field(...) + prompt: str | None = Field(...) + negative_prompt: str | None = Field(None) mode: str = Field("pro") sound: str = Field(..., description="'on' or 'off'") + multi_shot: bool | None = Field(None) + multi_prompt: list[MultiPromptEntry] | None = Field(None) + shot_type: str | None = Field(None) + + +class KlingAvatarRequest(BaseModel): + image: str = Field(...) + sound_file: str = Field(...) + prompt: str | None = Field(None) + mode: str = Field(...) + + +class MotionControlRequest(BaseModel): + prompt: str = Field(...) + image_url: str = Field(...) + video_url: str = Field(...) + keep_original_sound: str = Field(...) + character_orientation: str = Field(...) + mode: str = Field(..., description="'pro' or 'std'") + model_name: str = Field(...) diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma.py similarity index 100% rename from comfy_api_nodes/apis/luma_api.py rename to comfy_api_nodes/apis/luma.py diff --git a/comfy_api_nodes/apis/magnific.py b/comfy_api_nodes/apis/magnific.py new file mode 100644 index 000000000..b9f148def --- /dev/null +++ b/comfy_api_nodes/apis/magnific.py @@ -0,0 +1,122 @@ +from typing import TypedDict + +from pydantic import AliasChoices, BaseModel, Field, model_validator + + +class InputPortraitMode(TypedDict): + portrait_mode: str + portrait_style: str + portrait_beautifier: str + + +class InputAdvancedSettings(TypedDict): + advanced_settings: str + whites: int + blacks: int + brightness: int + contrast: int + saturation: int + engine: str + transfer_light_a: str + transfer_light_b: str + fixed_generation: bool + + +class InputSkinEnhancerMode(TypedDict): + mode: str + skin_detail: int + optimized_for: str + + +class ImageUpscalerCreativeRequest(BaseModel): + image: str = Field(...) + scale_factor: str = Field(...) + optimized_for: str = Field(...) + prompt: str | None = Field(None) + creativity: int = Field(...) + hdr: int = Field(...) + resemblance: int = Field(...) + fractality: int = Field(...) + engine: str = Field(...) + + +class ImageUpscalerPrecisionV2Request(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + ultra_detail: int = Field(...) + flavor: str = Field(...) + scale_factor: int = Field(...) + + +class ImageRelightAdvancedSettingsRequest(BaseModel): + whites: int = Field(...) + blacks: int = Field(...) + brightness: int = Field(...) + contrast: int = Field(...) + saturation: int = Field(...) + engine: str = Field(...) + transfer_light_a: str = Field(...) + transfer_light_b: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageRelightRequest(BaseModel): + image: str = Field(...) + prompt: str | None = Field(None) + transfer_light_from_reference_image: str | None = Field(None) + light_transfer_strength: int = Field(...) + interpolate_from_original: bool = Field(...) + change_background: bool = Field(...) + style: str = Field(...) + preserve_details: bool = Field(...) 
+ advanced_settings: ImageRelightAdvancedSettingsRequest | None = Field(...) + + +class ImageStyleTransferRequest(BaseModel): + image: str = Field(...) + reference_image: str = Field(...) + prompt: str | None = Field(None) + style_strength: int = Field(...) + structure_strength: int = Field(...) + is_portrait: bool = Field(...) + portrait_style: str | None = Field(...) + portrait_beautifier: str | None = Field(...) + flavor: str = Field(...) + engine: str = Field(...) + fixed_generation: bool = Field(...) + + +class ImageSkinEnhancerCreativeRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + + +class ImageSkinEnhancerFaithfulRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + skin_detail: int = Field(...) + + +class ImageSkinEnhancerFlexibleRequest(BaseModel): + image: str = Field(...) + sharpen: int = Field(...) + smart_grain: int = Field(...) + optimized_for: str = Field(...) + + +class TaskResponse(BaseModel): + """Unified response model that handles both wrapped and unwrapped API responses.""" + + task_id: str = Field(...) + status: str = Field(validation_alias=AliasChoices("status", "task_status")) + generated: list[str] | None = Field(None) + + @model_validator(mode="before") + @classmethod + def unwrap_data(cls, values: dict) -> dict: + if "data" in values and isinstance(values["data"], dict): + return values["data"] + return values diff --git a/comfy_api_nodes/apis/meshy.py b/comfy_api_nodes/apis/meshy.py new file mode 100644 index 000000000..7d72e6e91 --- /dev/null +++ b/comfy_api_nodes/apis/meshy.py @@ -0,0 +1,165 @@ +from typing import TypedDict + +from pydantic import BaseModel, Field + +from comfy_api.latest import Input + + +class InputShouldRemesh(TypedDict): + should_remesh: str + topology: str + target_polycount: int + + +class InputShouldTexture(TypedDict): + should_texture: str + enable_pbr: bool + texture_prompt: str + texture_image: Input.Image | None + + +class MeshyTaskResponse(BaseModel): + result: str = Field(...) + + +class MeshyTextToModelRequest(BaseModel): + mode: str = Field("preview") + prompt: str = Field(..., max_length=600) + art_style: str = Field(..., description="'realistic' or 'sculpture'") + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + pose_mode: str = Field(...) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRefineTask(BaseModel): + mode: str = Field("refine") + preview_task_id: str = Field(...) + enable_pbr: bool | None = Field(...) + texture_prompt: str | None = Field(...) + texture_image_url: str | None = Field(...) + ai_model: str = Field(...) + moderation: bool = Field(False) + + +class MeshyImageToModelRequest(BaseModel): + image_url: str = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) 
+ pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyMultiImageToModelRequest(BaseModel): + image_urls: list[str] = Field(...) + ai_model: str = Field(...) + topology: str | None = Field(..., description="'quad' or 'triangle'") + target_polycount: int | None = Field(..., ge=100, le=300000) + symmetry_mode: str = Field(..., description="'auto', 'off' or 'on'") + should_remesh: bool = Field( + True, + description="False returns the original mesh, ignoring topology and polycount.", + ) + should_texture: bool = Field(...) + enable_pbr: bool | None = Field(...) + pose_mode: str = Field(...) + texture_prompt: str | None = Field(None, max_length=600) + texture_image_url: str | None = Field(None) + seed: int = Field(...) + moderation: bool = Field(False) + + +class MeshyRiggingRequest(BaseModel): + input_task_id: str = Field(...) + height_meters: float = Field(...) + texture_image_url: str | None = Field(...) + + +class MeshyAnimationRequest(BaseModel): + rig_task_id: str = Field(...) + action_id: int = Field(...) + + +class MeshyTextureRequest(BaseModel): + input_task_id: str = Field(...) + ai_model: str = Field(...) + enable_original_uv: bool = Field(...) + enable_pbr: bool = Field(...) + text_style_prompt: str | None = Field(...) + image_style_url: str | None = Field(...) + + +class MeshyModelsUrls(BaseModel): + glb: str = Field("") + fbx: str = Field("") + usdz: str = Field("") + obj: str = Field("") + + +class MeshyRiggedModelsUrls(BaseModel): + rigged_character_glb_url: str = Field("") + rigged_character_fbx_url: str = Field("") + + +class MeshyAnimatedModelsUrls(BaseModel): + animation_glb_url: str = Field("") + animation_fbx_url: str = Field("") + + +class MeshyResultTextureUrls(BaseModel): + base_color: str = Field(...) + metallic: str | None = Field(None) + normal: str | None = Field(None) + roughness: str | None = Field(None) + + +class MeshyTaskError(BaseModel): + message: str | None = Field(None) + + +class MeshyModelResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + model_urls: MeshyModelsUrls = Field(MeshyModelsUrls()) + thumbnail_url: str = Field(...) + video_url: str | None = Field(None) + status: str = Field(...) + progress: int = Field(0) + texture_urls: list[MeshyResultTextureUrls] | None = Field([]) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyRiggedResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) + progress: int = Field(0) + result: MeshyRiggedModelsUrls = Field(MeshyRiggedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) + + +class MeshyAnimationResult(BaseModel): + id: str = Field(...) + type: str = Field(...) + status: str = Field(...) 
+ progress: int = Field(0) + result: MeshyAnimatedModelsUrls = Field(MeshyAnimatedModelsUrls()) + task_error: MeshyTaskError | None = Field(None) diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax.py similarity index 100% rename from comfy_api_nodes/apis/minimax_api.py rename to comfy_api_nodes/apis/minimax.py diff --git a/comfy_api_nodes/apis/moonvalley.py b/comfy_api_nodes/apis/moonvalley.py new file mode 100644 index 000000000..7ec7a4ade --- /dev/null +++ b/comfy_api_nodes/apis/moonvalley.py @@ -0,0 +1,152 @@ +from enum import Enum +from typing import Optional, Dict, Any + +from pydantic import BaseModel, Field, StrictBytes + + +class MoonvalleyPromptResponse(BaseModel): + error: Optional[Dict[str, Any]] = None + frame_conditioning: Optional[Dict[str, Any]] = None + id: Optional[str] = None + inference_params: Optional[Dict[str, Any]] = None + meta: Optional[Dict[str, Any]] = None + model_params: Optional[Dict[str, Any]] = None + output_url: Optional[str] = None + prompt_text: Optional[str] = None + status: Optional[str] = None + + +class MoonvalleyTextToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 75, description='Number of cooldown steps (calculated based on num_frames)' + ) + fps: Optional[int] = Field( + 24, description='Frames per second of the generated video' + ) + guidance_scale: Optional[float] = Field( + 10, description='Guidance scale for generation control' + ) + height: Optional[int] = Field( + 1080, description='Height of the generated video in pixels' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + num_frames: Optional[int] = Field(64, description='Number of frames to generate') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 0, description='Number of warmup steps (calculated based on num_frames)' + ) + width: Optional[int] = Field( + 1920, description='Width of the generated video in pixels' + ) + + +class MoonvalleyTextToVideoRequest(BaseModel): + image_url: Optional[str] = None + inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None + prompt_text: Optional[str] = None + webhook_url: Optional[str] = None + + +class MoonvalleyUploadFileRequest(BaseModel): + file: Optional[StrictBytes] = None + + +class MoonvalleyUploadFileResponse(BaseModel): + access_url: 
Optional[str] = None + + +class MoonvalleyVideoToVideoInferenceParams(BaseModel): + add_quality_guidance: Optional[bool] = Field( + True, description='Whether to add quality guidance' + ) + caching_coefficient: Optional[float] = Field( + 0.3, description='Caching coefficient for optimization' + ) + caching_cooldown: Optional[int] = Field( + 3, description='Number of caching cooldown steps' + ) + caching_warmup: Optional[int] = Field( + 3, description='Number of caching warmup steps' + ) + clip_value: Optional[float] = Field( + 3, description='CLIP value for generation control' + ) + conditioning_frame_index: Optional[int] = Field( + 0, description='Index of the conditioning frame' + ) + cooldown_steps: Optional[int] = Field( + 36, description='Number of cooldown steps (calculated based on num_frames)' + ) + guidance_scale: Optional[float] = Field( + 15, description='Guidance scale for generation control' + ) + negative_prompt: Optional[str] = Field(None, description='Negative prompt text') + seed: Optional[int] = Field( + None, description='Random seed for generation (default: random)' + ) + shift_value: Optional[float] = Field( + 3, description='Shift value for generation control' + ) + steps: Optional[int] = Field(80, description='Number of denoising steps') + use_guidance_schedule: Optional[bool] = Field( + True, description='Whether to use guidance scheduling' + ) + use_negative_prompts: Optional[bool] = Field( + False, description='Whether to use negative prompts' + ) + use_timestep_transform: Optional[bool] = Field( + True, description='Whether to use timestep transformation' + ) + warmup_steps: Optional[int] = Field( + 24, description='Number of warmup steps (calculated based on num_frames)' + ) + + +class ControlType(str, Enum): + motion_control = 'motion_control' + pose_control = 'pose_control' + + +class MoonvalleyVideoToVideoRequest(BaseModel): + control_type: ControlType = Field( + ..., description='Supported types for video control' + ) + inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None + prompt_text: str = Field(..., description='Describes the video to generate') + video_url: str = Field(..., description='Url to control video') + webhook_url: Optional[str] = Field( + None, description='Optional webhook URL for notifications' + ) diff --git a/comfy_api_nodes/apis/openai.py b/comfy_api_nodes/apis/openai.py new file mode 100644 index 000000000..b85ef252b --- /dev/null +++ b/comfy_api_nodes/apis/openai.py @@ -0,0 +1,170 @@ +from pydantic import BaseModel, Field + + +class Datum2(BaseModel): + b64_json: str | None = Field(None, description="Base64 encoded image data") + revised_prompt: str | None = Field(None, description="Revised prompt") + url: str | None = Field(None, description="URL of the image") + + +class InputTokensDetails(BaseModel): + image_tokens: int | None = Field(None) + text_tokens: int | None = Field(None) + + +class Usage(BaseModel): + input_tokens: int | None = Field(None) + input_tokens_details: InputTokensDetails | None = Field(None) + output_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + + +class OpenAIImageGenerationResponse(BaseModel): + data: list[Datum2] | None = Field(None) + usage: Usage | None = Field(None) + + +class OpenAIImageEditRequest(BaseModel): + background: str | None = Field(None, description="Background transparency") + model: str = Field(...) 
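+    # Only `model` and `prompt` are required here; the remaining edit parameters are optional.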
+    moderation: str | None = Field(None)
+    n: int | None = Field(None, description="The number of images to generate")
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the output image")
+
+
+class OpenAIImageGenerationRequest(BaseModel):
+    background: str | None = Field(None, description="Background transparency")
+    model: str | None = Field(None)
+    moderation: str | None = Field(None)
+    n: int | None = Field(
+        None,
+        description="The number of images to generate.",
+    )
+    output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)")
+    output_format: str | None = Field(None)
+    prompt: str = Field(...)
+    quality: str | None = Field(None, description="The quality of the generated image")
+    size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)")
+    style: str | None = Field(None, description="Style of the image (only for dall-e-3)")
+
+
+class ModelResponseProperties(BaseModel):
+    instructions: str | None = Field(None)
+    max_output_tokens: int | None = Field(None)
+    model: str | None = Field(None)
+    temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0)
+    top_p: float | None = Field(
+        1,
+        description="Controls diversity of the response via nucleus sampling",
+        ge=0.0,
+        le=1.0,
+    )
+    truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
+
+
+class ResponseProperties(BaseModel):
+    instructions: str | None = Field(None)
+    max_output_tokens: int | None = Field(None)
+    model: str | None = Field(None)
+    previous_response_id: str | None = Field(None)
+    truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'")
+
+
+class ResponseError(BaseModel):
+    code: str = Field(...)
+    message: str = Field(...)
+
+
+class OutputTokensDetails(BaseModel):
+    reasoning_tokens: int = Field(..., description="The number of reasoning tokens.")
+
+
+class CachedTokensDetails(BaseModel):
+    cached_tokens: int = Field(
+        ...,
+        description="The number of tokens that were retrieved from the cache.",
+    )
+
+
+class ResponseUsage(BaseModel):
+    input_tokens: int = Field(..., description="The number of input tokens.")
+    input_tokens_details: CachedTokensDetails = Field(...)
+    output_tokens: int = Field(..., description="The number of output tokens.")
+    output_tokens_details: OutputTokensDetails = Field(...)
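+    # total_tokens below should equal input_tokens + output_tokens; the model does not enforce this.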
+ total_tokens: int = Field(..., description="The total number of tokens used.") + + +class InputTextContent(BaseModel): + text: str = Field(..., description="The text input to the model.") + type: str = Field("input_text") + + +class OutputContent(BaseModel): + type: str = Field(..., description="The type of output content") + text: str | None = Field(None, description="The text content") + data: str | None = Field(None, description="Base64-encoded audio data") + transcript: str | None = Field(None, description="Transcript of the audio") + + +class OutputMessage(BaseModel): + type: str = Field(..., description="The type of output item") + content: list[OutputContent] | None = Field(None, description="The content of the message") + role: str | None = Field(None, description="The role of the message") + + +class OpenAIResponse(ModelResponseProperties, ResponseProperties): + created_at: float | None = Field( + None, + description="Unix timestamp (in seconds) of when this Response was created.", + ) + error: ResponseError | None = Field(None) + id: str | None = Field(None, description="Unique identifier for this Response.") + object: str | None = Field(None, description="The object type of this resource - always set to `response`.") + output: list[OutputMessage] | None = Field(None) + parallel_tool_calls: bool | None = Field(True) + status: str | None = Field( + None, + description="One of `completed`, `failed`, `in_progress`, or `incomplete`.", + ) + usage: ResponseUsage | None = Field(None) + + +class InputImageContent(BaseModel): + detail: str = Field(..., description="One of `high`, `low`, or `auto`. Defaults to `auto`.") + file_id: str | None = Field(None) + image_url: str | None = Field(None) + type: str = Field(..., description="The type of the input item. Always `input_image`.") + + +class InputFileContent(BaseModel): + file_data: str | None = Field(None) + file_id: str | None = Field(None) + filename: str | None = Field(None, description="The name of the file to be sent to the model.") + type: str = Field(..., description="The type of the input item. Always `input_file`.") + + +class InputMessage(BaseModel): + content: list[InputTextContent | InputImageContent | InputFileContent] = Field( + ..., + description="A list of one or many input items to the model, containing different content types.", + ) + role: str | None = Field(None) + type: str | None = Field(None) + + +class OpenAICreateResponse(ModelResponseProperties, ResponseProperties): + include: str | None = Field(None) + input: list[InputMessage] = Field(...) + parallel_tool_calls: bool | None = Field( + True, description="Whether to allow the model to run tool calls in parallel." 
+ ) + store: bool | None = Field( + True, + description="Whether to store the generated model response for later retrieval via API.", + ) + stream: bool | None = Field(False) + usage: ResponseUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openai_api.py b/comfy_api_nodes/apis/openai_api.py deleted file mode 100644 index ae5bb2673..000000000 --- a/comfy_api_nodes/apis/openai_api.py +++ /dev/null @@ -1,52 +0,0 @@ -from pydantic import BaseModel, Field - - -class Datum2(BaseModel): - b64_json: str | None = Field(None, description="Base64 encoded image data") - revised_prompt: str | None = Field(None, description="Revised prompt") - url: str | None = Field(None, description="URL of the image") - - -class InputTokensDetails(BaseModel): - image_tokens: int | None = None - text_tokens: int | None = None - - -class Usage(BaseModel): - input_tokens: int | None = None - input_tokens_details: InputTokensDetails | None = None - output_tokens: int | None = None - total_tokens: int | None = None - - -class OpenAIImageGenerationResponse(BaseModel): - data: list[Datum2] | None = None - usage: Usage | None = None - - -class OpenAIImageEditRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str = Field(...) - moderation: str | None = Field(None) - n: int | None = Field(None, description="The number of images to generate") - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - size: str | None = Field(None, description="Size of the output image") - - -class OpenAIImageGenerationRequest(BaseModel): - background: str | None = Field(None, description="Background transparency") - model: str | None = Field(None) - moderation: str | None = Field(None) - n: int | None = Field( - None, - description="The number of images to generate.", - ) - output_compression: int | None = Field(None, description="Compression level for JPEG or WebP (0-100)") - output_format: str | None = Field(None) - prompt: str = Field(...) - quality: str | None = Field(None, description="The quality of the generated image") - size: str | None = Field(None, description="Size of the image (e.g., 1024x1024, 1536x1024, auto)") - style: str | None = Field(None, description="Style of the image (only for dall-e-3)") diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse.py similarity index 100% rename from comfy_api_nodes/apis/pixverse_api.py rename to comfy_api_nodes/apis/pixverse.py diff --git a/comfy_api_nodes/apis/quiver.py b/comfy_api_nodes/apis/quiver.py new file mode 100644 index 000000000..bc8708754 --- /dev/null +++ b/comfy_api_nodes/apis/quiver.py @@ -0,0 +1,43 @@ +from pydantic import BaseModel, Field + + +class QuiverImageObject(BaseModel): + url: str = Field(...) + + +class QuiverTextToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + prompt: str = Field(...) + instructions: str | None = Field(default=None) + references: list[QuiverImageObject] | None = Field(default=None, max_length=4) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverImageToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + image: QuiverImageObject = Field(...) 
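+    # Sampling knobs below mirror QuiverTextToSVGRequest; auto_crop and target_size
+    # are assumed to control input preprocessing before vectorization.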
+ auto_crop: bool | None = Field(default=None) + target_size: int | None = Field(default=None, ge=128, le=4096) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverSVGResponseItem(BaseModel): + svg: str = Field(...) + mime_type: str | None = Field(default="image/svg+xml") + + +class QuiverSVGUsage(BaseModel): + total_tokens: int | None = Field(default=None) + input_tokens: int | None = Field(default=None) + output_tokens: int | None = Field(default=None) + + +class QuiverSVGResponse(BaseModel): + id: str | None = Field(default=None) + created: int | None = Field(default=None) + data: list[QuiverSVGResponseItem] = Field(...) + usage: QuiverSVGUsage | None = Field(default=None) diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft.py similarity index 70% rename from comfy_api_nodes/apis/recraft_api.py rename to comfy_api_nodes/apis/recraft.py index c36d95f24..78ededd94 100644 --- a/comfy_api_nodes/apis/recraft_api.py +++ b/comfy_api_nodes/apis/recraft.py @@ -1,11 +1,8 @@ from __future__ import annotations - - from enum import Enum -from typing import Optional -from pydantic import BaseModel, Field, conint, confloat +from pydantic import BaseModel, Field class RecraftColor: @@ -201,11 +198,6 @@ dict_recraft_substyles_v3 = { } -class RecraftModel(str, Enum): - recraftv3 = 'recraftv3' - recraftv2 = 'recraftv2' - - class RecraftImageSize(str, Enum): res_1024x1024 = '1024x1024' res_1365x1024 = '1365x1024' @@ -224,30 +216,64 @@ class RecraftImageSize(str, Enum): res_1707x1024 = '1707x1024' +RECRAFT_V4_SIZES = [ + "1024x1024", + "1536x768", + "768x1536", + "1280x832", + "832x1280", + "1216x896", + "896x1216", + "1152x896", + "896x1152", + "832x1344", + "1280x896", + "896x1280", + "1344x768", + "768x1344", +] + +RECRAFT_V4_PRO_SIZES = [ + "2048x2048", + "3072x1536", + "1536x3072", + "2560x1664", + "1664x2560", + "2432x1792", + "1792x2432", + "2304x1792", + "1792x2304", + "1664x2688", + "1434x1024", + "1024x1434", + "2560x1792", + "1792x2560", +] + + class RecraftColorObject(BaseModel): rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model') class RecraftControlsObject(BaseModel): - colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors') - background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color') - no_text: Optional[bool] = Field(None, description='Do not embed text layouts') - artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].') + colors: list[RecraftColorObject] | None = Field(None, description='An array of preferable colors') + background_color: RecraftColorObject | None = Field(None, description='Use given color as a desired background color') + no_text: bool | None = Field(None, description='Do not embed text layouts') + artistic_level: int | None = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. 
The value should be in range [0..5].') class RecraftImageGenerationRequest(BaseModel): prompt: str = Field(..., description='The text prompt describing the image to generate') - size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")') - n: conint(ge=1, le=6) = Field(..., description='The number of images to generate') - negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image') - model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') - style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') - substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input') - controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process') - style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID') - strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') - random_seed: Optional[int] = Field(None, description="Seed for video generation") - # text_layout + size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') + n: int = Field(..., description='The number of images to generate') + negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image') + model: str = Field(...) + style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') + substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input') + controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process') + style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID') + strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') + random_seed: int | None = Field(None, description="Seed for video generation") class RecraftReturnedObject(BaseModel): @@ -258,5 +284,13 @@ class RecraftReturnedObject(BaseModel): class RecraftImageGenerationResponse(BaseModel): created: int = Field(..., description='Unix timestamp when the generation was created') credits: int = Field(..., description='Number of credits used for the generation') - data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information') - image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image') + data: list[RecraftReturnedObject] | None = Field(None, description='Array of generated image information') + image: RecraftReturnedObject | None = Field(None, description='Single generated image') + + +class RecraftCreateStyleRequest(BaseModel): + style: str = Field(..., description="realistic_image, digital_illustration, vector_illustration, or icon") + + +class RecraftCreateStyleResponse(BaseModel): + id: str = Field(..., description="UUID of the created style") diff --git a/comfy_api_nodes/apis/reve.py b/comfy_api_nodes/apis/reve.py 
new file mode 100644 index 000000000..c6b5a69d8 --- /dev/null +++ b/comfy_api_nodes/apis/reve.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, Field + + +class RevePostprocessingOperation(BaseModel): + process: str = Field(..., description="The postprocessing operation: upscale or remove_background.") + upscale_factor: int | None = Field( + None, + description="Upscale factor (2, 3, or 4). Only used when process is upscale.", + ge=2, + le=4, + ) + + +class ReveImageCreateRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageEditRequest(BaseModel): + edit_instruction: str = Field(...) + reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageRemixRequest(BaseModel): + prompt: str = Field(...) + reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.") + aspect_ratio: str | None = Field(...) + version: str = Field(...) + test_time_scaling: int | None = Field( + ..., + description="If included, the model will spend more effort making better images. Values between 1 and 15.", + ge=1, + le=15, + ) + postprocessing: list[RevePostprocessingOperation] | None = Field( + None, description="Optional postprocessing operations to apply after generation." + ) + + +class ReveImageResponse(BaseModel): + image: str | None = Field(None, description="The base64 encoded image data.") + request_id: str | None = Field(None, description="A unique id for the request.") + credits_used: float | None = Field(None, description="The number of credits used for this request.") + version: str | None = Field(None, description="The specific model version used.") + content_violation: bool | None = Field( + None, description="Indicates whether the generated image violates the content policy." 
+ ) diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin.py similarity index 100% rename from comfy_api_nodes/apis/rodin_api.py rename to comfy_api_nodes/apis/rodin.py diff --git a/comfy_api_nodes/apis/runway.py b/comfy_api_nodes/apis/runway.py new file mode 100644 index 000000000..df6f2b845 --- /dev/null +++ b/comfy_api_nodes/apis/runway.py @@ -0,0 +1,127 @@ +from enum import Enum +from typing import Optional, List, Union +from datetime import datetime + +from pydantic import BaseModel, Field, RootModel + + +class RunwayAspectRatioEnum(str, Enum): + field_1280_720 = '1280:720' + field_720_1280 = '720:1280' + field_1104_832 = '1104:832' + field_832_1104 = '832:1104' + field_960_960 = '960:960' + field_1584_672 = '1584:672' + field_1280_768 = '1280:768' + field_768_1280 = '768:1280' + + +class Position(str, Enum): + first = 'first' + last = 'last' + + +class RunwayPromptImageDetailedObject(BaseModel): + position: Position = Field( + ..., + description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.", + ) + uri: str = Field( + ..., description='A HTTPS URL or data URI containing an encoded image.' + ) + + +class RunwayPromptImageObject( + RootModel[Union[str, List[RunwayPromptImageDetailedObject]]] +): + root: Union[str, List[RunwayPromptImageDetailedObject]] = Field( + ..., + description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.', + ) + + +class RunwayModelEnum(str, Enum): + gen4_turbo = 'gen4_turbo' + gen3a_turbo = 'gen3a_turbo' + + +class RunwayDurationEnum(int, Enum): + integer_5 = 5 + integer_10 = 10 + + +class RunwayImageToVideoRequest(BaseModel): + duration: RunwayDurationEnum + model: RunwayModelEnum + promptImage: RunwayPromptImageObject + promptText: Optional[str] = Field( + None, description='Text prompt for the generation', max_length=1000 + ) + ratio: RunwayAspectRatioEnum + seed: int = Field( + ..., description='Random seed for generation', ge=0, le=4294967295 + ) + + +class RunwayImageToVideoResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') + + +class RunwayTaskStatusEnum(str, Enum): + SUCCEEDED = 'SUCCEEDED' + RUNNING = 'RUNNING' + FAILED = 'FAILED' + PENDING = 'PENDING' + CANCELLED = 'CANCELLED' + THROTTLED = 'THROTTLED' + + +class RunwayTaskStatusResponse(BaseModel): + createdAt: datetime = Field(..., description='Task creation timestamp') + id: str = Field(..., description='Task ID') + output: Optional[List[str]] = Field(None, description='Array of output video URLs') + progress: Optional[float] = Field( + None, + description='Float value between 0 and 1 representing the progress of the task. 
Only available if status is RUNNING.', + ge=0.0, + le=1.0, + ) + status: RunwayTaskStatusEnum + + +class Model4(str, Enum): + gen4_image = 'gen4_image' + + +class ReferenceImage(BaseModel): + uri: Optional[str] = Field( + None, description='A HTTPS URL or data URI containing an encoded image' + ) + + +class RunwayTextToImageAspectRatioEnum(str, Enum): + field_1920_1080 = '1920:1080' + field_1080_1920 = '1080:1920' + field_1024_1024 = '1024:1024' + field_1360_768 = '1360:768' + field_1080_1080 = '1080:1080' + field_1168_880 = '1168:880' + field_1440_1080 = '1440:1080' + field_1080_1440 = '1080:1440' + field_1808_768 = '1808:768' + field_2112_912 = '2112:912' + + +class RunwayTextToImageRequest(BaseModel): + model: Model4 = Field(..., description='Model to use for generation') + promptText: str = Field( + ..., description='Text prompt for the image generation', max_length=1000 + ) + ratio: RunwayTextToImageAspectRatioEnum + referenceImages: Optional[List[ReferenceImage]] = Field( + None, description='Array of reference images to guide the generation' + ) + + +class RunwayTextToImageResponse(BaseModel): + id: Optional[str] = Field(None, description='Task ID') diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability.py similarity index 100% rename from comfy_api_nodes/apis/stability_api.py rename to comfy_api_nodes/apis/stability.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz.py similarity index 97% rename from comfy_api_nodes/apis/topaz_api.py rename to comfy_api_nodes/apis/topaz.py index 4d9e62e72..a9e6235a7 100644 --- a/comfy_api_nodes/apis/topaz_api.py +++ b/comfy_api_nodes/apis/topaz.py @@ -41,7 +41,7 @@ class Resolution(BaseModel): height: int = Field(...) -class CreateCreateVideoRequestSource(BaseModel): +class CreateVideoRequestSource(BaseModel): container: str = Field(...) size: int = Field(..., description="Size of the video file in bytes") duration: int = Field(..., description="Duration of the video file in seconds") @@ -89,7 +89,7 @@ class Overrides(BaseModel): class CreateVideoRequest(BaseModel): - source: CreateCreateVideoRequestSource = Field(...) + source: CreateVideoRequestSource = Field(...) filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) output: OutputInformationVideo = Field(...) overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo.py similarity index 100% rename from comfy_api_nodes/apis/tripo_api.py rename to comfy_api_nodes/apis/tripo.py diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo.py similarity index 100% rename from comfy_api_nodes/apis/veo_api.py rename to comfy_api_nodes/apis/veo.py diff --git a/comfy_api_nodes/apis/vidu.py b/comfy_api_nodes/apis/vidu.py new file mode 100644 index 000000000..469adcdbc --- /dev/null +++ b/comfy_api_nodes/apis/vidu.py @@ -0,0 +1,65 @@ +from pydantic import BaseModel, Field + + +class SubjectReference(BaseModel): + id: str = Field(...) + images: list[str] = Field(...) + + +class FrameSetting(BaseModel): + prompt: str = Field(...) + key_image: str = Field(...) + duration: int = Field(...) + + +class TaskMultiFrameCreationRequest(BaseModel): + model: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + start_image: str = Field(...) + image_settings: list[FrameSetting] = Field(...) + + +class TaskExtendCreationRequest(BaseModel): + model: str = Field(...) 
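+    # Extend tasks reuse the core generation fields and add the source video_url below.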
+ prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + resolution: str = Field(...) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + video_url: str = Field(..., description="URL of the video to extend") + + +class TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(..., max_length=2000) + duration: int = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + aspect_ratio: str | None = Field(None) + resolution: str | None = Field(None) + movement_amplitude: str | None = Field(None) + images: list[str] | None = Field(None, description="Base64 encoded string or image URL") + subjects: list[SubjectReference] | None = Field(None) + bgm: bool | None = Field(None) + audio: bool | None = Field(None) + + +class TaskCreationResponse(BaseModel): + task_id: str = Field(...) + state: str = Field(...) + created_at: str = Field(...) + code: int | None = Field(None, description="Error code") + + +class TaskResult(BaseModel): + id: str = Field(..., description="Creation id") + url: str = Field(..., description="The URL of the generated results, valid for one hour") + cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") + + +class TaskStatusResponse(BaseModel): + state: str = Field(...) + err_code: str | None = Field(None) + progress: float | None = Field(None) + credits: int | None = Field(None) + creations: list[TaskResult] = Field(..., description="Generated results") diff --git a/comfy_api_nodes/apis/wavespeed.py b/comfy_api_nodes/apis/wavespeed.py new file mode 100644 index 000000000..07a7bfa5d --- /dev/null +++ b/comfy_api_nodes/apis/wavespeed.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class SeedVR2ImageRequest(BaseModel): + image: str = Field(...) + target_resolution: str = Field(...) + output_format: str = Field("png") + enable_sync_mode: bool = Field(False) + + +class FlashVSRRequest(BaseModel): + target_resolution: str = Field(...) + video: str = Field(...) + duration: float = Field(...) + + +class TaskCreatedDataResponse(BaseModel): + id: str = Field(...) + + +class TaskCreatedResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) + data: TaskCreatedDataResponse | None = Field(None) + + +class TaskResultDataResponse(BaseModel): + status: str = Field(...) + outputs: list[str] = Field([]) + + +class TaskResultResponse(BaseModel): + code: int = Field(...) + message: str = Field(...) 
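+    # Assumption: `data` stays None until the task has produced a result.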
+ data: TaskResultDataResponse | None = Field(None) diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py deleted file mode 100644 index 4df7590b6..000000000 --- a/comfy_api_nodes/canary.py +++ /dev/null @@ -1,10 +0,0 @@ -import av - -ver = av.__version__.split(".") -if int(ver[0]) < 14: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -if int(ver[0]) == 14 and int(ver[1]) < 2: - raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.") - -NODE_CLASS_MAPPINGS = {} diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py deleted file mode 100644 index 6fab8f4bb..000000000 --- a/comfy_api_nodes/mapper_utils.py +++ /dev/null @@ -1,116 +0,0 @@ -from enum import Enum - -from pydantic.fields import FieldInfo -from pydantic import BaseModel -from pydantic_core import PydanticUndefined - -from comfy.comfy_types.node_typing import IO, InputTypeOptions - -NodeInput = tuple[IO, InputTypeOptions] - - -def _create_base_config(field_info: FieldInfo) -> InputTypeOptions: - config = {} - if hasattr(field_info, "default") and field_info.default is not PydanticUndefined: - config["default"] = field_info.default - if hasattr(field_info, "description") and field_info.description is not None: - config["tooltip"] = field_info.description - return config - - -def _get_number_constraints_config(field_info: FieldInfo) -> dict: - config = {} - if hasattr(field_info, "metadata"): - metadata = field_info.metadata - for constraint in metadata: - if hasattr(constraint, "ge"): - config["min"] = constraint.ge - if hasattr(constraint, "le"): - config["max"] = constraint.le - if hasattr(constraint, "multiple_of"): - config["step"] = constraint.multiple_of - return config - - -def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.IMAGE, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.STRING, { - **_create_base_config(field_info), - **kwargs, - } - - -def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.FLOAT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput: - return IO.INT, { - **_create_base_config(field_info), - **_get_number_constraints_config(field_info), - **kwargs, - } - - -def _model_field_to_combo_input( - field_info: FieldInfo, enum_type: type[Enum] = None, **kwargs -) -> NodeInput: - combo_config = {} - if enum_type is not None: - combo_config["options"] = [option.value for option in enum_type] - combo_config = { - **combo_config, - **_create_base_config(field_info), - **kwargs, - } - return IO.COMBO, combo_config - - -def model_field_to_node_input( - input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs -) -> NodeInput: - """ - Maps a field from a Pydantic model to a Comfy node input. - - Args: - input_type: The type of the input. - base_model: The Pydantic model to map the field from. - field_name: The name of the field to map. - **kwargs: Additional key/values to include in the input options. - - Note: - For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically. 
- - Example: - >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True) - >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum) - >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True) - """ - field_info: FieldInfo = base_model.model_fields[field_name] - result: NodeInput - - if input_type == IO.IMAGE: - result = _model_field_to_image_input(field_info, **kwargs) - elif input_type == IO.STRING: - result = _model_field_to_string_input(field_info, **kwargs) - elif input_type == IO.FLOAT: - result = _model_field_to_float_input(field_info, **kwargs) - elif input_type == IO.INT: - result = _model_field_to_int_input(field_info, **kwargs) - elif input_type == IO.COMBO: - result = _model_field_to_combo_input(field_info, **kwargs) - else: - message = f"Invalid input type: {input_type}" - raise ValueError(message) - - return result diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 8826dea0c..23590bf24 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,11 +1,9 @@ -from inspect import cleandoc - import torch from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.bfl_api import ( +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bfl import ( BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, @@ -28,7 +26,7 @@ from comfy_api_nodes.util import ( ) -def convert_mask_to_image(mask: torch.Tensor): +def convert_mask_to_image(mask: Input.Image): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ @@ -38,9 +36,6 @@ def convert_mask_to_image(mask: torch.Tensor): class FluxProUltraImageNode(IO.ComfyNode): - """ - Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -48,7 +43,7 @@ class FluxProUltraImageNode(IO.ComfyNode): node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( "prompt", @@ -62,6 +57,7 @@ class FluxProUltraImageNode(IO.ComfyNode): tooltip="Whether to perform upsampling on the prompt. " "If active, automatically modifies the prompt for more creative generation, " "but results are nondeterministic (same seed will not produce exactly the same result).", + advanced=True, ), IO.Int.Input( "seed", @@ -102,6 +98,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.06}""", + ), ) @classmethod @@ -117,7 +116,7 @@ class FluxProUltraImageNode(IO.ComfyNode): prompt_upsampling: bool = False, raw: bool = False, seed: int = 0, - image_prompt: torch.Tensor | None = None, + image_prompt: Input.Image | None = None, image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: @@ -155,9 +154,6 @@ class FluxProUltraImageNode(IO.ComfyNode): class FluxKontextProImageNode(IO.ComfyNode): - """ - Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. 
- """ @classmethod def define_schema(cls) -> IO.Schema: @@ -165,7 +161,7 @@ class FluxKontextProImageNode(IO.ComfyNode): node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -205,6 +201,7 @@ class FluxKontextProImageNode(IO.ComfyNode): "prompt_upsampling", default=False, tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + advanced=True, ), IO.Image.Input( "input_image", @@ -231,7 +228,7 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: torch.Tensor | None = None, + input_image: Input.Image | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: @@ -271,20 +268,14 @@ class FluxKontextProImageNode(IO.ComfyNode): class FluxKontextMaxImageNode(FluxKontextProImageNode): - """ - Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio. - """ - DESCRIPTION = cleandoc(__doc__ or "") + DESCRIPTION = "Edits images using Flux.1 Kontext [max] via api based on prompt and aspect ratio." BFL_PATH = "/proxy/bfl/flux-kontext-max/generate" NODE_ID = "FluxKontextMaxImageNode" DISPLAY_NAME = "Flux.1 Kontext [max] Image" class FluxProExpandNode(IO.ComfyNode): - """ - Outpaints image based on prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -292,7 +283,7 @@ class FluxProExpandNode(IO.ComfyNode): node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -307,6 +298,7 @@ class FluxProExpandNode(IO.ComfyNode): tooltip="Whether to perform upsampling on the prompt. " "If active, automatically modifies the prompt for more creative generation, " "but results are nondeterministic (same seed will not produce exactly the same result).", + advanced=True, ), IO.Int.Input( "top", @@ -366,12 +358,15 @@ class FluxProExpandNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, prompt_upsampling: bool, top: int, @@ -418,9 +413,6 @@ class FluxProExpandNode(IO.ComfyNode): class FluxProFillNode(IO.ComfyNode): - """ - Inpaints image based on mask and prompt. - """ @classmethod def define_schema(cls) -> IO.Schema: @@ -428,7 +420,7 @@ class FluxProFillNode(IO.ComfyNode): node_id="FluxProFillNode", display_name="Flux.1 Fill Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), IO.Mask.Input("mask"), @@ -444,6 +436,7 @@ class FluxProFillNode(IO.ComfyNode): tooltip="Whether to perform upsampling on the prompt. 
" "If active, automatically modifies the prompt for more creative generation, " "but results are nondeterministic (same seed will not produce exactly the same result).", + advanced=True, ), IO.Float.Input( "guidance", @@ -475,13 +468,16 @@ class FluxProFillNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.05}""", + ), ) @classmethod async def execute( cls, - image: torch.Tensor, - mask: torch.Tensor, + image: Input.Image, + mask: Input.Image, prompt: str, prompt_upsampling: bool, steps: int, @@ -525,11 +521,30 @@ class FluxProFillNode(IO.ComfyNode): class Flux2ProImageNode(IO.ComfyNode): + NODE_ID = "Flux2ProImageNode" + DISPLAY_NAME = "Flux.2 [pro] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-pro/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + $outputCost := 0.03 + 0.015 * ($outMP - 1); + inputs.images.connected + ? { + "type":"range_usd", + "min_usd": $outputCost + 0.015, + "max_usd": $outputCost + 0.12, + "format": { "approximate": true } + } + : {"type":"usd","usd": $outputCost} + ) + """ + @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="Flux2ProImageNode", - display_name="Flux.2 [pro] Image", + node_id=cls.NODE_ID, + display_name=cls.DISPLAY_NAME, category="api node/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ @@ -563,12 +578,12 @@ class Flux2ProImageNode(IO.ComfyNode): ), IO.Boolean.Input( "prompt_upsampling", - default=False, + default=True, tooltip="Whether to perform upsampling on the prompt. " - "If active, automatically modifies the prompt for more creative generation, " - "but results are nondeterministic (same seed will not produce exactly the same result).", + "If active, automatically modifies the prompt for more creative generation.", + advanced=True, ), - IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), + IO.Image.Input("images", optional=True, tooltip="Up to 9 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -577,6 +592,10 @@ class Flux2ProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["width", "height"], inputs=["images"]), + expr=cls.PRICE_BADGE_EXPR, + ), ) @classmethod @@ -587,7 +606,7 @@ class Flux2ProImageNode(IO.ComfyNode): height: int, seed: int, prompt_upsampling: bool, - images: torch.Tensor | None = None, + images: Input.Image | None = None, ) -> IO.NodeOutput: reference_images = {} if images is not None: @@ -598,7 +617,7 @@ class Flux2ProImageNode(IO.ComfyNode): reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) initial_response = await sync_op( cls, - ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + ApiEndpoint(path=cls.API_ENDPOINT, method="POST"), response_model=BFLFluxProGenerateResponse, data=Flux2ProGenerateRequest( prompt=prompt, @@ -632,6 +651,29 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class Flux2MaxImageNode(Flux2ProImageNode): + + NODE_ID = "Flux2MaxImageNode" + DISPLAY_NAME = "Flux.2 [max] Image" + API_ENDPOINT = "/proxy/bfl/flux-2-max/generate" + PRICE_BADGE_EXPR = """ + ( + $MP := 1024 * 1024; + $outMP := $max([1, $floor(((widgets.width * widgets.height) + $MP - 1) / $MP)]); + 
$outputCost := 0.07 + 0.03 * ($outMP - 1);
+
+        inputs.images.connected
+        ? {
+            "type":"range_usd",
+            "min_usd": $outputCost + 0.03,
+            "max_usd": $outputCost + 0.24,
+            "format": { "approximate": true }
+          }
+        : {"type":"usd","usd": $outputCost}
+    )
+    """
+
+
 class BFLExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -642,6 +684,7 @@ class BFLExtension(ComfyExtension):
             FluxProExpandNode,
             FluxProFillNode,
             Flux2ProImageNode,
+            Flux2MaxImageNode,
         ]
diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py
new file mode 100644
index 000000000..4044ee3ea
--- /dev/null
+++ b/comfy_api_nodes/nodes_bria.py
@@ -0,0 +1,330 @@
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.bria import (
+    BriaEditImageRequest,
+    BriaRemoveBackgroundRequest,
+    BriaRemoveBackgroundResponse,
+    BriaRemoveVideoBackgroundRequest,
+    BriaRemoveVideoBackgroundResponse,
+    BriaImageEditResponse,
+    BriaStatusResponse,
+    InputModerationSettings,
+)
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    convert_mask_to_image,
+    download_url_to_image_tensor,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    upload_image_to_comfyapi,
+    upload_video_to_comfyapi,
+    validate_video_duration,
+)
+
+
+class BriaImageEditNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="BriaImageEditNode",
+            display_name="Bria FIBO Image Edit",
+            category="api node/image/Bria",
+            description="Edit images using Bria's latest model",
+            inputs=[
+                IO.Combo.Input("model", options=["FIBO"]),
+                IO.Image.Input("image"),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    default="",
+                    tooltip="Instruction to edit image",
+                ),
+                IO.String.Input("negative_prompt", multiline=True, default=""),
+                IO.String.Input(
+                    "structured_prompt",
+                    multiline=True,
+                    default="",
+                    tooltip="A string containing the structured edit prompt in JSON format. "
+                    "Use this instead of the usual prompt for precise, programmatic control.",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=1,
+                    min=1,
+                    max=2147483647,
+                    step=1,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                ),
+                IO.Float.Input(
+                    "guidance_scale",
+                    default=3,
+                    min=3,
+                    max=5,
+                    step=0.01,
+                    display_mode=IO.NumberDisplay.number,
+                    tooltip="Higher value makes the image follow the prompt more closely.",
+                ),
+                IO.Int.Input(
+                    "steps",
+                    default=50,
+                    min=20,
+                    max=50,
+                    step=1,
+                    display_mode=IO.NumberDisplay.number,
+                ),
+                IO.DynamicCombo.Input(
+                    "moderation",
+                    options=[
+                        IO.DynamicCombo.Option("false", []),
+                        IO.DynamicCombo.Option(
+                            "true",
+                            [
+                                IO.Boolean.Input("prompt_content_moderation", default=False),
+                                IO.Boolean.Input("visual_input_moderation", default=False),
+                                IO.Boolean.Input("visual_output_moderation", default=True),
+                            ],
+                        ),
+                    ],
+                    tooltip="Moderation settings",
+                ),
+                IO.Mask.Input(
+                    "mask",
+                    tooltip="If omitted, the edit applies to the entire image.",
+                    optional=True,
+                ),
+            ],
+            outputs=[
+                IO.Image.Output(),
+                IO.String.Output(display_name="structured_prompt"),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd":0.04}""",
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        image: Input.Image,
+        prompt: str,
+        negative_prompt: str,
+        structured_prompt: str,
+        seed: int,
+        guidance_scale: float,
+        steps: int,
+        moderation: InputModerationSettings,
+        mask: Input.Image | None = None,
+    ) -> IO.NodeOutput:
+        if not prompt and not structured_prompt:
+            raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
+        mask_url = None
+        if mask is not None:
+            mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/bria/v2/image/edit", method="POST"),
+            data=BriaEditImageRequest(
+                instruction=prompt if prompt else None,
+                structured_instruction=structured_prompt if structured_prompt else None,
+                images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
+                mask=mask_url,
+                negative_prompt=negative_prompt if negative_prompt else None,
+                guidance_scale=guidance_scale,
+                seed=seed,
+                model_version=model,
+                steps_num=steps,
+                prompt_content_moderation=moderation.get("prompt_content_moderation", False),
+                visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+                visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+            ),
+            response_model=BriaStatusResponse,
+        )
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+            status_extractor=lambda r: r.status,
+            response_model=BriaImageEditResponse,
+        )
+        return IO.NodeOutput(
+            await download_url_to_image_tensor(response.result.image_url),
+            response.result.structured_prompt,
+        )
+
+
+class BriaRemoveImageBackground(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="BriaRemoveImageBackground",
+            display_name="Bria Remove Image Background",
+            category="api node/image/Bria",
+            description="Remove the background from an image using Bria RMBG 2.0.",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.DynamicCombo.Input(
+                    "moderation",
+                    options=[
+                        IO.DynamicCombo.Option("false", []),
+                        IO.DynamicCombo.Option(
+                            "true",
+                            [
+                                IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True), + ], + ), + ], + tooltip="Moderation settings", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.018}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + moderation: dict, + seed: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"), + data=BriaRemoveBackgroundRequest( + image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"), + sync=False, + visual_input_content_moderation=moderation.get("visual_input_moderation", False), + visual_output_content_moderation=moderation.get("visual_output_moderation", False), + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url)) + + +class BriaRemoveVideoBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaRemoveVideoBackground", + display_name="Bria Remove Video Background", + category="api node/video/Bria", + description="Remove the background from a video using Bria. 
", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input( + "background_color", + options=[ + "Black", + "White", + "Gray", + "Red", + "Green", + "Blue", + "Yellow", + "Cyan", + "Magenta", + "Orange", + ], + tooltip="Background color for the output video.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + background_color: str, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), + data=BriaRemoveVideoBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_color=background_color, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + +class BriaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BriaImageEditNode, + BriaRemoveImageBackground, + BriaRemoveVideoBackground, + ] + + +async def comfy_entrypoint() -> BriaExtension: + return BriaExtension() diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 57c0218d0..de0c22e70 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -5,11 +5,10 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis.bytedance_api import ( +from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, VIDEO_TASKS_EXECUTION_TIME, - Image2ImageTaskCreationRequest, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedream4Options, @@ -38,10 +37,20 @@ from comfy_api_nodes.util import ( BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +SEEDREAM_MODELS = { + "seedream 5.0 lite": "seedream-5-0-260128", + "seedream-4-5-251128": "seedream-4-5-251128", + "seedream-4-0-250828": "seedream-4-0-250828", +} + # Long-running tasks endpoints(e.g., video) BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} + +logger = logging.getLogger(__name__) + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: @@ -112,9 +121,10 @@ class ByteDanceImageNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image', optional=True, + advanced=True, ), ], outputs=[ @@ -126,6 +136,10 @@ class ByteDanceImageNode(IO.ComfyNode): 
IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), + is_deprecated=True, ) @classmethod @@ -171,112 +185,19 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) -class ByteDanceImageEditNode(IO.ComfyNode): - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="ByteDanceImageEditNode", - display_name="ByteDance Image Edit", - category="api node/image/ByteDance", - description="Edit images using ByteDance models via api based on prompt", - inputs=[ - IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), - IO.Image.Input( - "image", - tooltip="The base image to edit", - ), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Instruction to edit image", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=2147483647, - step=1, - display_mode=IO.NumberDisplay.number, - control_after_generate=True, - tooltip="Seed to use for generation", - optional=True, - ), - IO.Float.Input( - "guidance_scale", - default=5.5, - min=1.0, - max=10.0, - step=0.01, - display_mode=IO.NumberDisplay.number, - tooltip="Higher value makes the image follow the prompt more closely", - optional=True, - ), - IO.Boolean.Input( - "watermark", - default=True, - tooltip='Whether to add an "AI generated" watermark to the image', - optional=True, - ), - ], - outputs=[ - IO.Image.Output(), - ], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - model: str, - image: Input.Image, - prompt: str, - seed: int, - guidance_scale: float, - watermark: bool, - ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=True, min_length=1) - if get_number_of_images(image) != 1: - raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio(image, (1, 3), (3, 1)) - source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] - payload = Image2ImageTaskCreationRequest( - model=model, - prompt=prompt, - image=source_url, - seed=seed, - guidance_scale=guidance_scale, - watermark=watermark, - ) - response = await sync_op( - cls, - ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), - data=payload, - response_model=ImageTaskCreationResponse, - ) - return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) - - class ByteDanceSeedreamNode(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="ByteDanceSeedreamNode", - display_name="ByteDance Seedream 4", + display_name="ByteDance Seedream 4.5 & 5.0", category="api node/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( "model", - options=["seedream-4-5-251128", "seedream-4-0-250828"], - tooltip="Model name", + options=list(SEEDREAM_MODELS.keys()), ), IO.String.Input( "prompt", @@ -287,7 +208,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. 
" - "List of 1-10 images for single or multi-reference generation.", + "Reference image(s) for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -299,8 +220,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "width", default=2048, min=1024, - max=4096, - step=8, + max=6240, + step=2, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -308,8 +229,8 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "height", default=2048, min=1024, - max=4096, - step=8, + max=4992, + step=2, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -346,15 +267,17 @@ class ByteDanceSeedreamNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, + advanced=True, ), IO.Boolean.Input( "fail_on_partial", default=True, tooltip="If enabled, abort execution if any requested images are missing or return an error.", optional=True, + advanced=True, ), ], outputs=[ @@ -366,6 +289,20 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $price := $contains(widgets.model, "5.0 lite") ? 0.035 : + $contains(widgets.model, "4-5") ? 0.04 : 0.03; + { + "type":"usd", + "usd": $price, + "format": { "suffix":" x images/Run", "approximate": true } + } + ) + """, + ), ) @classmethod @@ -380,9 +317,10 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation: str = "disabled", max_images: int = 1, seed: int = 0, - watermark: bool = True, + watermark: bool = False, fail_on_partial: bool = True, ) -> IO.NodeOutput: + model = SEEDREAM_MODELS[model] validate_string(prompt, strip_whitespace=True, min_length=1) w = h = None for label, tw, th in RECOMMENDED_PRESETS_SEEDREAM_4: @@ -392,15 +330,12 @@ class ByteDanceSeedreamNode(IO.ComfyNode): if w is None or h is None: w, h = width, height - if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): - raise ValueError( - f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." - ) + out_num_pixels = w * h mp_provided = out_num_pixels / 1_000_000.0 - if "seedream-4-5" in model and out_num_pixels < 3686400: + if ("seedream-4-5" in model or "seedream-5-0" in model) and out_num_pixels < 3686400: raise ValueError( - f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"Minimum image resolution for the selected model is 3.68MP, " f"but {mp_provided:.2f}MP provided." ) if "seedream-4-0" in model and out_num_pixels < 921600: @@ -408,9 +343,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode): f"Minimum image resolution that the selected model can generate is 0.92MP, " f"but {mp_provided:.2f}MP provided." ) + max_pixels = 10_404_496 if "seedream-5-0" in model else 16_777_216 + if out_num_pixels > max_pixels: + raise ValueError( + f"Maximum image resolution for the selected model is {max_pixels / 1_000_000:.2f}MP, " + f"but {mp_provided:.2f}MP provided." 
+ ) n_input_images = get_number_of_images(image) if image is not None else 0 - if n_input_images > 10: - raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") + max_num_of_images = 14 if model == "seedream-5-0-260128" else 10 + if n_input_images > max_num_of_images: + raise ValueError( + f"Maximum of {max_num_of_images} reference images are supported, but {n_input_images} received." + ) if sequential_image_generation == "auto" and n_input_images + max_images > 15: raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." @@ -438,6 +382,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): sequential_image_generation=sequential_image_generation, sequential_image_generation_options=Seedream4Options(max_images=max_images), watermark=watermark, + output_format="png" if model == "seedream-5-0-260128" else None, ), ) if len(response.data) == 1: @@ -460,7 +405,12 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-t2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -504,12 +454,21 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): tooltip="Specifies whether to fix the camera. The platform appends an instruction " "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, + advanced=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + advanced=True, ), ], outputs=[ @@ -521,6 +480,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -534,7 +494,10 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) @@ -549,7 +512,11 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): ) return await process_video_task( cls, - payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), + payload=Text2VideoTaskCreationRequest( + model=model, + content=[TaskTextContent(text=prompt)], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, + ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -566,7 +533,12 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + options=[ + "seedance-1-5-pro-251215", + "seedance-1-0-pro-250528", + "seedance-1-0-lite-i2v-250428", + "seedance-1-0-pro-fast-251015", + ], default="seedance-1-0-pro-fast-251015", ), IO.String.Input( @@ -614,12 +586,21 @@ class 
ByteDanceImageToVideoNode(IO.ComfyNode): tooltip="Specifies whether to fix the camera. The platform appends an instruction " "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, + advanced=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + advanced=True, ), ], outputs=[ @@ -631,6 +612,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -645,7 +627,10 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) @@ -667,6 +652,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -684,7 +670,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + options=["seedance-1-5-pro-251215", "seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( @@ -736,12 +722,21 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): tooltip="Specifies whether to fix the camera. 
The platform appends an instruction " "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, + advanced=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + tooltip="This parameter is ignored for any model except seedance-1-5-pro.", + optional=True, + advanced=True, ), ], outputs=[ @@ -753,6 +748,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -768,7 +764,10 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): seed: int, camera_fixed: bool, watermark: bool, + generate_audio: bool = False, ) -> IO.NodeOutput: + if model == "seedance-1-5-pro-251215" and duration < 4: + raise ValueError("Minimum supported duration for Seedance 1.5 Pro is 4 seconds.") validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): @@ -801,6 +800,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[0])), role="first_frame"), TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], + generate_audio=generate_audio if model == "seedance-1-5-pro-251215" else None, ), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -862,9 +862,10 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, + advanced=True, ), ], outputs=[ @@ -876,6 +877,41 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $priceByModel := { + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $minCost := $min10s * $scale; + $maxCost := $max10s * $scale; + ($minCost = $maxCost) + ? 
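/* equal min/max 10s rates collapse to a single price, else a USD range is shown */ 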
{"type":"usd","usd": $minCost} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost} + ) + """, + ), ) @classmethod @@ -911,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ] return await process_video_task( cls, - payload=Image2VideoTaskCreationRequest(model=model, content=x), + payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -921,6 +957,12 @@ async def process_video_task( payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, estimated_duration: int | None, ) -> IO.NodeOutput: + if payload.model in DEPRECATED_MODELS: + logger.warning( + "Model '%s' is deprecated and will be deactivated on May 13, 2026. " + "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.", + payload.model, + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), @@ -945,12 +987,64 @@ def raise_if_text_params(prompt: str, text_params: list[str]) -> None: ) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution", "generate_audio"]), + expr=""" + ( + $priceByModel := { + "seedance-1-5-pro": { + "480p":[0.12,0.12], + "720p":[0.26,0.26], + "1080p":[0.58,0.59] + }, + "seedance-1-0-pro": { + "480p":[0.23,0.24], + "720p":[0.51,0.56], + "1080p":[1.18,1.22] + }, + "seedance-1-0-pro-fast": { + "480p":[0.09,0.1], + "720p":[0.21,0.23], + "1080p":[0.47,0.49] + }, + "seedance-1-0-lite": { + "480p":[0.17,0.18], + "720p":[0.37,0.41], + "1080p":[0.85,0.88] + } + }; + $model := widgets.model; + $modelKey := + $contains($model, "seedance-1-5-pro") ? "seedance-1-5-pro" : + $contains($model, "seedance-1-0-pro-fast") ? "seedance-1-0-pro-fast" : + $contains($model, "seedance-1-0-pro") ? "seedance-1-0-pro" : + "seedance-1-0-lite"; + $resolution := widgets.resolution; + $resKey := + $contains($resolution, "1080") ? "1080p" : + $contains($resolution, "720") ? "720p" : + "480p"; + $modelPrices := $lookup($priceByModel, $modelKey); + $baseRange := $lookup($modelPrices, $resKey); + $min10s := $baseRange[0]; + $max10s := $baseRange[1]; + $scale := widgets.duration / 10; + $audioMultiplier := ($modelKey = "seedance-1-5-pro" and widgets.generate_audio) ? 2 : 1; + $minCost := $min10s * $scale * $audioMultiplier; + $maxCost := $max10s * $scale * $audioMultiplier; + ($minCost = $maxCost) + ? 
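/* duration scales the 10s base rates; generate_audio doubles cost on seedance-1-5-pro only */ 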
{"type":"usd","usd": $minCost, "format": { "approximate": true }} + : {"type":"range_usd","min_usd": $minCost, "max_usd": $maxCost, "format": { "approximate": true }} + ) + """, +) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ ByteDanceImageNode, - ByteDanceImageEditNode, ByteDanceSeedreamNode, ByteDanceTextToVideoNode, ByteDanceImageToVideoNode, diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py new file mode 100644 index 000000000..e452daf77 --- /dev/null +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -0,0 +1,924 @@ +import json +import uuid + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.elevenlabs import ( + AddVoiceRequest, + AddVoiceResponse, + DialogueInput, + DialogueSettings, + SpeechToSpeechRequest, + SpeechToTextRequest, + SpeechToTextResponse, + TextToDialogueRequest, + TextToSoundEffectsRequest, + TextToSpeechRequest, + TextToSpeechVoiceSettings, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_bytes_to_audio_input, + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + sync_op, + sync_op_raw, + upload_audio_to_comfyapi, + validate_string, +) + +ELEVENLABS_MUSIC_SECTIONS = "ELEVENLABS_MUSIC_SECTIONS" # Custom type for music sections +ELEVENLABS_COMPOSITION_PLAN = "ELEVENLABS_COMPOSITION_PLAN" # Custom type for composition plan +ELEVENLABS_VOICE = "ELEVENLABS_VOICE" # Custom type for voice selection + +# Predefined ElevenLabs voices: (voice_id, display_name, gender, accent) +ELEVENLABS_VOICES = [ + ("CwhRBWXzGAHq8TQ4Fs17", "Roger", "male", "american"), + ("EXAVITQu4vr4xnSDxMaL", "Sarah", "female", "american"), + ("FGY2WhTYpPnrIDTdsKH5", "Laura", "female", "american"), + ("IKne3meq5aSn9XLyUdCD", "Charlie", "male", "australian"), + ("JBFqnCBsd6RMkjVDRZzb", "George", "male", "british"), + ("N2lVS1w4EtoT3dr4eOWO", "Callum", "male", "american"), + ("SAz9YHcvj6GT2YYXdXww", "River", "neutral", "american"), + ("SOYHLrjzK2X1ezoPC6cr", "Harry", "male", "american"), + ("TX3LPaxmHKxFdv7VOQHJ", "Liam", "male", "american"), + ("Xb7hH8MSUJpSbSDYk0k2", "Alice", "female", "british"), + ("XrExE9yKIg1WjnnlVkGX", "Matilda", "female", "american"), + ("bIHbv24MWmeRgasZH58o", "Will", "male", "american"), + ("cgSgspJ2msm6clMCkdW9", "Jessica", "female", "american"), + ("cjVigY5qzO86Huf0OWal", "Eric", "male", "american"), + ("hpp4J3VqNfWAUOO0d1Us", "Bella", "female", "american"), + ("iP95p4xoKVk53GoZ742B", "Chris", "male", "american"), + ("nPczCjzI2devNBz1zQrb", "Brian", "male", "american"), + ("onwK4e9ZLuTAKqWW03F9", "Daniel", "male", "british"), + ("pFZP5JQG7iQjIQuC4Bku", "Lily", "female", "british"), + ("pNInz6obpgDQGcFmaJgB", "Adam", "male", "american"), + ("pqHfZKP75CvOlQylNhV4", "Bill", "male", "american"), +] + +ELEVENLABS_VOICE_OPTIONS = [f"{name} ({gender}, {accent})" for _, name, gender, accent in ELEVENLABS_VOICES] +ELEVENLABS_VOICE_MAP = { + f"{name} ({gender}, {accent})": voice_id for voice_id, name, gender, accent in ELEVENLABS_VOICES +} + + +class ElevenLabsSpeechToText(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsSpeechToText", + display_name="ElevenLabs Speech to Text", + category="api node/audio/ElevenLabs", + description="Transcribe audio to text. 
" + "Supports automatic language detection, speaker diarization, and audio event tagging.", + inputs=[ + IO.Audio.Input( + "audio", + tooltip="Audio to transcribe.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "scribe_v2", + [ + IO.Boolean.Input( + "tag_audio_events", + default=False, + tooltip="Annotate sounds like (laughter), (music), etc. in transcript.", + ), + IO.Boolean.Input( + "diarize", + default=False, + tooltip="Annotate which speaker is talking.", + ), + IO.Float.Input( + "diarization_threshold", + default=0.22, + min=0.1, + max=0.4, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Speaker separation sensitivity. " + "Lower values are more sensitive to speaker changes.", + ), + IO.Float.Input( + "temperature", + default=0.0, + min=0.0, + max=2.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. " + "0.0 uses model default. Higher values increase randomness.", + ), + IO.Combo.Input( + "timestamps_granularity", + options=["word", "character", "none"], + default="word", + tooltip="Timing precision for transcript words.", + ), + ], + ), + ], + tooltip="Model to use for transcription.", + ), + IO.String.Input( + "language_code", + default="", + tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). " + "Leave empty for automatic detection.", + ), + IO.Int.Input( + "num_speakers", + default=0, + min=0, + max=32, + display_mode=IO.NumberDisplay.slider, + tooltip="Maximum number of speakers to predict. Set to 0 for automatic detection.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + tooltip="Seed for reproducibility (determinism not guaranteed).", + ), + ], + outputs=[ + IO.String.Output(display_name="text"), + IO.String.Output(display_name="language_code"), + IO.String.Output(display_name="words_json"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.0073,"format":{"approximate":true,"suffix":"/minute"}}""", + ), + ) + + @classmethod + async def execute( + cls, + audio: Input.Audio, + model: dict, + language_code: str, + num_speakers: int, + seed: int, + ) -> IO.NodeOutput: + if model["diarize"] and num_speakers: + raise ValueError( + "Number of speakers cannot be specified when diarization is enabled. " + "Either disable diarization or set num_speakers to 0." 
+ ) + request = SpeechToTextRequest( + model_id=model["model"], + cloud_storage_url=await upload_audio_to_comfyapi( + cls, audio, container_format="mp4", codec_name="aac", mime_type="audio/mp4" + ), + language_code=language_code if language_code.strip() else None, + tag_audio_events=model["tag_audio_events"], + num_speakers=num_speakers if num_speakers > 0 else None, + timestamps_granularity=model["timestamps_granularity"], + diarize=model["diarize"], + diarization_threshold=model["diarization_threshold"] if model["diarize"] else None, + seed=seed, + temperature=model["temperature"], + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/elevenlabs/v1/speech-to-text", method="POST"), + response_model=SpeechToTextResponse, + data=request, + content_type="multipart/form-data", + ) + words_json = json.dumps( + [w.model_dump(exclude_none=True) for w in response.words] if response.words else [], + indent=2, + ) + return IO.NodeOutput(response.text, response.language_code, words_json) + + +class ElevenLabsVoiceSelector(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsVoiceSelector", + display_name="ElevenLabs Voice Selector", + category="api node/audio/ElevenLabs", + description="Select a predefined ElevenLabs voice for text-to-speech generation.", + inputs=[ + IO.Combo.Input( + "voice", + options=ELEVENLABS_VOICE_OPTIONS, + tooltip="Choose a voice from the predefined ElevenLabs voices.", + ), + ], + outputs=[ + IO.Custom(ELEVENLABS_VOICE).Output(display_name="voice"), + ], + is_api_node=False, + ) + + @classmethod + def execute(cls, voice: str) -> IO.NodeOutput: + voice_id = ELEVENLABS_VOICE_MAP.get(voice) + if not voice_id: + raise ValueError(f"Unknown voice: {voice}") + return IO.NodeOutput(voice_id) + + +class ElevenLabsTextToSpeech(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsTextToSpeech", + display_name="ElevenLabs Text to Speech", + category="api node/audio/ElevenLabs", + description="Convert text to speech.", + inputs=[ + IO.Custom(ELEVENLABS_VOICE).Input( + "voice", + tooltip="Voice to use for speech synthesis. Connect from Voice Selector or Instant Voice Clone.", + ), + IO.String.Input( + "text", + multiline=True, + default="", + tooltip="The text to convert to speech.", + ), + IO.Float.Input( + "stability", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Voice stability. Lower values give broader emotional range, " + "higher values produce more consistent but potentially monotonous speech.", + ), + IO.Combo.Input( + "apply_text_normalization", + options=["auto", "on", "off"], + tooltip="Text normalization mode. 'auto' lets the system decide, " + "'on' always applies normalization, 'off' skips it.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "eleven_multilingual_v2", + [ + IO.Float.Input( + "speed", + default=1.0, + min=0.7, + max=1.3, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.", + ), + IO.Float.Input( + "similarity_boost", + default=0.75, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Similarity boost. 
Higher values make the voice more similar to the original.", + ), + IO.Boolean.Input( + "use_speaker_boost", + default=False, + tooltip="Boost similarity to the original speaker voice.", + ), + IO.Float.Input( + "style", + default=0.0, + min=0.0, + max=0.2, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Style exaggeration. Higher values increase stylistic expression " + "but may reduce stability.", + ), + ], + ), + IO.DynamicCombo.Option( + "eleven_v3", + [ + IO.Float.Input( + "speed", + default=1.0, + min=0.7, + max=1.3, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.", + ), + IO.Float.Input( + "similarity_boost", + default=0.75, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Similarity boost. Higher values make the voice more similar to the original.", + ), + ], + ), + ], + tooltip="Model to use for text-to-speech.", + ), + IO.String.Input( + "language_code", + default="", + tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). " + "Leave empty for automatic detection.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + tooltip="Seed for reproducibility (determinism not guaranteed).", + ), + IO.Combo.Input( + "output_format", + options=["mp3_44100_192", "opus_48000_192"], + tooltip="Audio output format.", + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""", + ), + ) + + @classmethod + async def execute( + cls, + voice: str, + text: str, + stability: float, + apply_text_normalization: str, + model: dict, + language_code: str, + seed: int, + output_format: str, + ) -> IO.NodeOutput: + validate_string(text, min_length=1) + request = TextToSpeechRequest( + text=text, + model_id=model["model"], + language_code=language_code if language_code.strip() else None, + voice_settings=TextToSpeechVoiceSettings( + stability=stability, + similarity_boost=model["similarity_boost"], + speed=model["speed"], + use_speaker_boost=model.get("use_speaker_boost", None), + style=model.get("style", None), + ), + seed=seed, + apply_text_normalization=apply_text_normalization, + ) + response = await sync_op_raw( + cls, + ApiEndpoint( + path=f"/proxy/elevenlabs/v1/text-to-speech/{voice}", + method="POST", + query_params={"output_format": output_format}, + ), + data=request, + as_binary=True, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(response)) + + +class ElevenLabsAudioIsolation(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsAudioIsolation", + display_name="ElevenLabs Voice Isolation", + category="api node/audio/ElevenLabs", + description="Remove background noise from audio, isolating vocals or speech.", + inputs=[ + IO.Audio.Input( + "audio", + tooltip="Audio to process for background noise removal.", + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""", + ), + ) + + @classmethod + async def execute( + cls, + audio: Input.Audio, + ) -> IO.NodeOutput: + audio_data_np = 
audio_tensor_to_contiguous_ndarray(audio["waveform"]) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") + response = await sync_op_raw( + cls, + ApiEndpoint(path="/proxy/elevenlabs/v1/audio-isolation", method="POST"), + files={"audio": ("audio.mp4", audio_bytes_io, "audio/mp4")}, + content_type="multipart/form-data", + as_binary=True, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(response)) + + +class ElevenLabsTextToSoundEffects(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsTextToSoundEffects", + display_name="ElevenLabs Text to Sound Effects", + category="api node/audio/ElevenLabs", + description="Generate sound effects from text descriptions.", + inputs=[ + IO.String.Input( + "text", + multiline=True, + default="", + tooltip="Text description of the sound effect to generate.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "eleven_sfx_v2", + [ + IO.Float.Input( + "duration", + default=5.0, + min=0.5, + max=30.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of generated sound in seconds.", + ), + IO.Boolean.Input( + "loop", + default=False, + tooltip="Create a smoothly looping sound effect.", + ), + IO.Float.Input( + "prompt_influence", + default=0.3, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="How closely generation follows the prompt. " + "Higher values make the sound follow the text more closely.", + ), + ], + ), + ], + tooltip="Model to use for sound effect generation.", + ), + IO.Combo.Input( + "output_format", + options=["mp3_44100_192", "opus_48000_192"], + tooltip="Audio output format.", + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"approximate":true,"suffix":"/minute"}}""", + ), + ) + + @classmethod + async def execute( + cls, + text: str, + model: dict, + output_format: str, + ) -> IO.NodeOutput: + validate_string(text, min_length=1) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/elevenlabs/v1/sound-generation", + method="POST", + query_params={"output_format": output_format}, + ), + data=TextToSoundEffectsRequest( + text=text, + duration_seconds=model["duration"], + prompt_influence=model["prompt_influence"], + loop=model.get("loop", None), + ), + as_binary=True, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(response)) + + +class ElevenLabsInstantVoiceClone(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsInstantVoiceClone", + display_name="ElevenLabs Instant Voice Clone", + category="api node/audio/ElevenLabs", + description="Create a cloned voice from audio samples. 
" + "Provide 1-8 audio recordings of the voice to clone.", + inputs=[ + IO.Autogrow.Input( + "files", + template=IO.Autogrow.TemplatePrefix( + IO.Audio.Input("audio"), + prefix="audio", + min=1, + max=8, + ), + tooltip="Audio recordings for voice cloning.", + ), + IO.Boolean.Input( + "remove_background_noise", + default=False, + tooltip="Remove background noise from voice samples using audio isolation.", + ), + ], + outputs=[ + IO.Custom(ELEVENLABS_VOICE).Output(display_name="voice"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr="""{"type":"usd","usd":0.15}"""), + ) + + @classmethod + async def execute( + cls, + files: IO.Autogrow.Type, + remove_background_noise: bool, + ) -> IO.NodeOutput: + file_tuples: list[tuple[str, tuple[str, bytes, str]]] = [] + for key in files: + audio = files[key] + sample_rate: int = audio["sample_rate"] + waveform = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, "mp4", "aac") + file_tuples.append(("files", (f"{key}.mp4", audio_bytes_io.getvalue(), "audio/mp4"))) + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/elevenlabs/v1/voices/add", method="POST"), + response_model=AddVoiceResponse, + data=AddVoiceRequest( + name=str(uuid.uuid4()), + remove_background_noise=remove_background_noise, + ), + files=file_tuples, + content_type="multipart/form-data", + ) + return IO.NodeOutput(response.voice_id) + + +ELEVENLABS_STS_VOICE_SETTINGS = [ + IO.Float.Input( + "speed", + default=1.0, + min=0.7, + max=1.3, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Speech speed. 1.0 is normal, <1.0 slower, >1.0 faster.", + ), + IO.Float.Input( + "similarity_boost", + default=0.75, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Similarity boost. Higher values make the voice more similar to the original.", + ), + IO.Boolean.Input( + "use_speaker_boost", + default=False, + tooltip="Boost similarity to the original speaker voice.", + ), + IO.Float.Input( + "style", + default=0.0, + min=0.0, + max=0.2, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Style exaggeration. Higher values increase stylistic expression but may reduce stability.", + ), +] + + +class ElevenLabsSpeechToSpeech(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsSpeechToSpeech", + display_name="ElevenLabs Speech to Speech", + category="api node/audio/ElevenLabs", + description="Transform speech from one voice to another while preserving the original content and emotion.", + inputs=[ + IO.Custom(ELEVENLABS_VOICE).Input( + "voice", + tooltip="Target voice for the transformation. " + "Connect from Voice Selector or Instant Voice Clone.", + ), + IO.Audio.Input( + "audio", + tooltip="Source audio to transform.", + ), + IO.Float.Input( + "stability", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.slider, + tooltip="Voice stability. 
Lower values give broader emotional range, " + "higher values produce more consistent but potentially monotonous speech.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "eleven_multilingual_sts_v2", + ELEVENLABS_STS_VOICE_SETTINGS, + ), + IO.DynamicCombo.Option( + "eleven_english_sts_v2", + ELEVENLABS_STS_VOICE_SETTINGS, + ), + ], + tooltip="Model to use for speech-to-speech transformation.", + ), + IO.Combo.Input( + "output_format", + options=["mp3_44100_192", "opus_48000_192"], + tooltip="Audio output format.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=4294967295, + tooltip="Seed for reproducibility.", + ), + IO.Boolean.Input( + "remove_background_noise", + default=False, + tooltip="Remove background noise from input audio using audio isolation.", + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/minute"}}""", + ), + ) + + @classmethod + async def execute( + cls, + voice: str, + audio: Input.Audio, + stability: float, + model: dict, + output_format: str, + seed: int, + remove_background_noise: bool, + ) -> IO.NodeOutput: + audio_data_np = audio_tensor_to_contiguous_ndarray(audio["waveform"]) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, audio["sample_rate"], "mp4", "aac") + voice_settings = TextToSpeechVoiceSettings( + stability=stability, + similarity_boost=model["similarity_boost"], + style=model["style"], + use_speaker_boost=model["use_speaker_boost"], + speed=model["speed"], + ) + response = await sync_op_raw( + cls, + ApiEndpoint( + path=f"/proxy/elevenlabs/v1/speech-to-speech/{voice}", + method="POST", + query_params={"output_format": output_format}, + ), + data=SpeechToSpeechRequest( + model_id=model["model"], + voice_settings=voice_settings.model_dump_json(exclude_none=True), + seed=seed, + remove_background_noise=remove_background_noise, + ), + files={"audio": ("audio.mp4", audio_bytes_io.getvalue(), "audio/mp4")}, + content_type="multipart/form-data", + as_binary=True, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(response)) + + +def _generate_dialogue_inputs(count: int) -> list: + """Generate input widgets for a given number of dialogue entries.""" + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"text{i}", + multiline=True, + default="", + tooltip=f"Text content for dialogue entry {i}.", + ), + IO.Custom(ELEVENLABS_VOICE).Input( + f"voice{i}", + tooltip=f"Voice for dialogue entry {i}. Connect from Voice Selector or Instant Voice Clone.", + ), + ] + ) + return inputs + + +class ElevenLabsTextToDialogue(IO.ComfyNode): + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ElevenLabsTextToDialogue", + display_name="ElevenLabs Text to Dialogue", + category="api node/audio/ElevenLabs", + description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", + inputs=[ + IO.Float.Input( + "stability", + default=0.5, + min=0.0, + max=1.0, + step=0.5, + display_mode=IO.NumberDisplay.slider, + tooltip="Voice stability. Lower values give broader emotional range, " + "higher values produce more consistent but potentially monotonous speech.", + ), + IO.Combo.Input( + "apply_text_normalization", + options=["auto", "on", "off"], + tooltip="Text normalization mode. 
'auto' lets the system decide, " + "'on' always applies normalization, 'off' skips it.", + ), + IO.Combo.Input( + "model", + options=["eleven_v3"], + tooltip="Model to use for dialogue generation.", + ), + IO.DynamicCombo.Input( + "inputs", + options=[ + IO.DynamicCombo.Option("1", _generate_dialogue_inputs(1)), + IO.DynamicCombo.Option("2", _generate_dialogue_inputs(2)), + IO.DynamicCombo.Option("3", _generate_dialogue_inputs(3)), + IO.DynamicCombo.Option("4", _generate_dialogue_inputs(4)), + IO.DynamicCombo.Option("5", _generate_dialogue_inputs(5)), + IO.DynamicCombo.Option("6", _generate_dialogue_inputs(6)), + IO.DynamicCombo.Option("7", _generate_dialogue_inputs(7)), + IO.DynamicCombo.Option("8", _generate_dialogue_inputs(8)), + IO.DynamicCombo.Option("9", _generate_dialogue_inputs(9)), + IO.DynamicCombo.Option("10", _generate_dialogue_inputs(10)), + ], + tooltip="Number of dialogue entries.", + ), + IO.String.Input( + "language_code", + default="", + tooltip="ISO-639-1 or ISO-639-3 language code (e.g., 'en', 'es', 'fra'). " + "Leave empty for automatic detection.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=4294967295, + tooltip="Seed for reproducibility.", + ), + IO.Combo.Input( + "output_format", + options=["mp3_44100_192", "opus_48000_192"], + tooltip="Audio output format.", + ), + ], + outputs=[ + IO.Audio.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.24,"format":{"approximate":true,"suffix":"/1K chars"}}""", + ), + ) + + @classmethod + async def execute( + cls, + stability: float, + apply_text_normalization: str, + model: str, + inputs: dict, + language_code: str, + seed: int, + output_format: str, + ) -> IO.NodeOutput: + num_entries = int(inputs["inputs"]) + dialogue_inputs: list[DialogueInput] = [] + for i in range(1, num_entries + 1): + text = inputs[f"text{i}"] + voice_id = inputs[f"voice{i}"] + validate_string(text, min_length=1) + dialogue_inputs.append(DialogueInput(text=text, voice_id=voice_id)) + request = TextToDialogueRequest( + inputs=dialogue_inputs, + model_id=model, + language_code=language_code if language_code.strip() else None, + settings=DialogueSettings(stability=stability), + seed=seed, + apply_text_normalization=apply_text_normalization, + ) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/elevenlabs/v1/text-to-dialogue", + method="POST", + query_params={"output_format": output_format}, + ), + data=request, + as_binary=True, + ) + return IO.NodeOutput(audio_bytes_to_audio_input(response)) + + +class ElevenLabsExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ElevenLabsSpeechToText, + ElevenLabsVoiceSelector, + ElevenLabsTextToSpeech, + ElevenLabsAudioIsolation, + ElevenLabsTextToSoundEffects, + ElevenLabsInstantVoiceClone, + ElevenLabsSpeechToSpeech, + ElevenLabsTextToDialogue, + ] + + +async def comfy_entrypoint() -> ElevenLabsExtension: + return ElevenLabsExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index ad0f4b4d1..25d747e76 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer import base64 import os from enum import Enum +from fnmatch import fnmatch from io import BytesIO from typing import Literal @@ -14,7 +15,7 @@ from 
typing_extensions import override

import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input, Types
-from comfy_api_nodes.apis.gemini_api import (
+from comfy_api_nodes.apis.gemini import (
     GeminiContent,
     GeminiFileData,
     GeminiGenerateContentRequest,
@@ -28,12 +29,14 @@ from comfy_api_nodes.apis.gemini_api import (
     GeminiRole,
     GeminiSystemInstructionContent,
     GeminiTextPart,
+    GeminiThinkingConfig,
     Modality,
 )
 from comfy_api_nodes.util import (
     ApiEndpoint,
     audio_to_base64_string,
     bytesio_to_image_tensor,
+    download_url_to_image_tensor,
     get_number_of_images,
     sync_op,
     tensor_to_base64_string,
@@ -53,17 +56,20 @@ GEMINI_IMAGE_SYS_PROMPT = (
     "Prioritize generating the visual representation above any text, formatting, or conversational requests."
 )
-
-class GeminiModel(str, Enum):
-    """
-    Gemini Model Names allowed by comfy-api
-    """
-
-    gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
-    gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
-    gemini_2_5_pro = "gemini-2.5-pro"
-    gemini_2_5_flash = "gemini-2.5-flash"
-    gemini_3_0_pro = "gemini-3-pro-preview"
+GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
+    depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution"]),
+    expr="""
+    (
+        $m := $lowercase(widgets.model);
+        $r := $lowercase(widgets.resolution);
+        $isFlash := $contains($m, "nano banana 2");
+        $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
+        $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
+        $prices := $isFlash ? $flashPrices : $proPrices;
+        {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
+    )
+    """,
+)


 class GeminiImageModel(str, Enum):
@@ -118,6 +124,13 @@ async def create_image_parts(
     return image_parts


+def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
+    """Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
+    if mime is None:
+        return False
+    return fnmatch(mime.value, pattern)
+
+
 def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
     """
     Filter response parts by their type.
@@ -129,7 +142,7 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
     Returns:
         List of response parts matching the requested type.
     """
-    if response.candidates is None:
+    if not response.candidates:
         if response.promptFeedback and response.promptFeedback.blockReason:
             feedback = response.promptFeedback
             raise ValueError(
@@ -140,12 +153,24 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
            "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
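+            # reached when the response carries no candidates and no explicit
+            # blockReason; with the IMAGE-only modality the model's text
+            # explanation is never returned, hence the suggestion above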
) parts = [] - for part in response.candidates[0].content.parts: - if part_type == "text" and hasattr(part, "text") and part.text: - parts.append(part) - elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: - parts.append(part) - # Skip parts that don't match the requested type + blocked_reasons = [] + for candidate in response.candidates: + if candidate.finishReason and candidate.finishReason.upper() == "IMAGE_PROHIBITED_CONTENT": + blocked_reasons.append(candidate.finishReason) + continue + if candidate.content is None or candidate.content.parts is None: + continue + for part in candidate.content.parts: + if part_type == "text" and part.text: + parts.append(part) + elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type): + parts.append(part) + elif part.fileData and _mime_matches(part.fileData.mimeType, part_type): + parts.append(part) + + if not parts and blocked_reasons: + raise ValueError(f"Gemini API blocked the request. Reasons: {blocked_reasons}") + return parts @@ -163,12 +188,17 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image: image_tensors: list[Input.Image] = [] - parts = get_parts_by_type(response, "image/png") + parts = get_parts_by_type(response, "image/*") for part in parts: - image_data = base64.b64decode(part.inlineData.data) - returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + if (part.thought is True) != thought: + continue + if part.inlineData: + image_data = base64.b64decode(part.inlineData.data) + returned_image = bytesio_to_image_tensor(BytesIO(image_data)) + else: + returned_image = await download_url_to_image_tensor(part.fileData.fileUri) image_tensors.append(returned_image) if len(image_tensors) == 0: return torch.zeros((1, 1024, 1024, 4)) @@ -197,14 +227,22 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N input_tokens_price = 0.30 output_text_tokens_price = 2.50 output_image_tokens_price = 30.0 - elif response.modelVersion == "gemini-3-pro-preview": + elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"): input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3.1-flash-lite-preview": + input_tokens_price = 0.25 + output_text_tokens_price = 1.50 + output_image_tokens_price = 0.0 elif response.modelVersion == "gemini-3-pro-image-preview": input_tokens_price = 2 output_text_tokens_price = 12.0 output_image_tokens_price = 120.0 + elif response.modelVersion == "gemini-3.1-flash-image-preview": + input_tokens_price = 0.5 + output_text_tokens_price = 3.0 + output_image_tokens_price = 60.0 else: return None final_price = response.usageMetadata.promptTokenCount * input_tokens_price @@ -248,8 +286,16 @@ class GeminiNode(IO.ComfyNode): ), IO.Combo.Input( "model", - options=GeminiModel, - default=GeminiModel.gemini_2_5_pro, + options=[ + "gemini-2.5-pro-preview-05-06", + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-pro", + "gemini-2.5-flash", + "gemini-3-pro-preview", + "gemini-3-1-pro", + "gemini-3-1-flash-lite", + ], + default="gemini-3-1-pro", tooltip="The Gemini model to use for generating responses.", ), IO.Int.Input( @@ -292,6 +338,7 @@ class GeminiNode(IO.ComfyNode): default="", 
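+                    # assumption: as with the image nodes below, an empty string
+                    # means no systemInstruction is sent with the request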
                    optional=True,
                    tooltip="Foundational instructions that dictate an AI's behavior.",
+                    advanced=True,
                ),
            ],
            outputs=[
@@ -303,6 +350,35 @@ class GeminiNode(IO.ComfyNode):
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=["model"]),
+                expr="""
+                (
+                    $m := widgets.model;
+                    $contains($m, "gemini-2.5-flash") ? {
+                        "type": "list_usd",
+                        "usd": [0.0003, 0.0025],
+                        "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens"}
+                    }
+                    : $contains($m, "gemini-2.5-pro") ? {
+                        "type": "list_usd",
+                        "usd": [0.00125, 0.01],
+                        "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
+                    }
+                    : ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? {
+                        "type": "list_usd",
+                        "usd": [0.002, 0.012],
+                        "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
+                    }
+                    : $contains($m, "gemini-3-1-flash-lite") ? {
+                        "type": "list_usd",
+                        "usd": [0.00025, 0.0015],
+                        "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
+                    }
+                    : {"type":"text", "text":"Token-based"}
+                )
+                """,
+            ),
        )

    @classmethod
@@ -367,12 +443,14 @@ class GeminiNode(IO.ComfyNode):
        files: list[GeminiPart] | None = None,
        system_prompt: str = "",
    ) -> IO.NodeOutput:
-        validate_string(prompt, strip_whitespace=False)
+        if model == "gemini-3-pro-preview":
+            model = "gemini-3.1-pro-preview"  # model "gemini-3-pro-preview" will soon be deprecated by Google
+        elif model == "gemini-3-1-pro":
+            model = "gemini-3.1-pro-preview"
+        elif model == "gemini-3-1-flash-lite":
+            model = "gemini-3.1-flash-lite-preview"

-        # Create parts list with text prompt as the first part
        parts: list[GeminiPart] = [GeminiPart(text=prompt)]
-
-        # Add other modal parts
        if images is not None:
            parts.extend(await create_image_parts(cls, images))
        if audio is not None:
@@ -545,6 +623,7 @@ class GeminiImage(IO.ComfyNode):
                    tooltip="Choose 'IMAGE' for image-only output, or "
                    "'IMAGE+TEXT' to return both the generated image and a text response.",
                    optional=True,
+                    advanced=True,
                ),
                IO.String.Input(
                    "system_prompt",
@@ -552,6 +631,7 @@
                    default=GEMINI_IMAGE_SYS_PROMPT,
                    optional=True,
                    tooltip="Foundational instructions that dictate an AI's behavior.",
+                    advanced=True,
                ),
            ],
            outputs=[
@@ -564,6 +644,9 @@
            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd":0.039,"format":{"suffix":"/Image (1K)","approximate":true}}""",
+            ),
        )

    @classmethod
@@ -583,7 +666,7 @@
        if not aspect_ratio:
            aspect_ratio = "auto"  # for backward compatability with old workflows; to-do remove this in December

-        image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
+        image_config = GeminiImageConfig() if aspect_ratio == "auto" else GeminiImageConfig(aspectRatio=aspect_ratio)

        if images is not None:
            parts.extend(await create_image_parts(cls, images))
@@ -596,21 +679,21 @@

        response = await sync_op(
            cls,
-            endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
+            ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"),
            data=GeminiImageGenerateContentRequest(
                contents=[
                    GeminiContent(role=GeminiRole.user, parts=parts),
                ],
                generationConfig=GeminiImageGenerationConfig(
                    responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
-                    imageConfig=None if aspect_ratio == "auto" else image_config,
+                    imageConfig=image_config,
                ),
                systemInstruction=gemini_system_prompt,
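+                # imageConfig is now always sent: "auto" is expressed by omitting
+                # aspectRatio from the config built above, rather than by dropping
+                # the config object entirely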
systemInstruction=gemini_system_prompt, ), response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) class GeminiImage2(IO.ComfyNode): @@ -632,7 +715,7 @@ class GeminiImage2(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gemini-3-pro-image-preview"], + options=["gemini-3-pro-image-preview", "Nano Banana 2 (Gemini 3.1 Flash Image)"], ), IO.Int.Input( "seed", @@ -663,6 +746,7 @@ class GeminiImage2(IO.ComfyNode): options=["IMAGE+TEXT", "IMAGE"], tooltip="Choose 'IMAGE' for image-only output, or " "'IMAGE+TEXT' to return both the generated image and a text response.", + advanced=True, ), IO.Image.Input( "images", @@ -682,6 +766,7 @@ class GeminiImage2(IO.ComfyNode): default=GEMINI_IMAGE_SYS_PROMPT, optional=True, tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, ), ], outputs=[ @@ -694,6 +779,7 @@ class GeminiImage2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, ) @classmethod @@ -710,6 +796,8 @@ class GeminiImage2(IO.ComfyNode): system_prompt: str = "", ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" parts: list[GeminiPart] = [GeminiPart(text=prompt)] if images is not None: @@ -729,7 +817,7 @@ class GeminiImage2(IO.ComfyNode): response = await sync_op( cls, - ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), data=GeminiImageGenerateContentRequest( contents=[ GeminiContent(role=GeminiRole.user, parts=parts), @@ -743,7 +831,179 @@ class GeminiImage2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + + +class GeminiNanoBanana2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNanoBanana2", + display_name="Nano Banana 2", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["Nano Banana 2 (Gemini 3.1 Flash Image)"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. 
" + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "4:5", + "5:4", + "9:16", + "16:9", + "21:9", + # "1:4", + # "4:1", + # "8:1", + # "1:8", + ], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=[ + # "512px", + "1K", + "2K", + "4K", + ], + tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE", "IMAGE+TEXT"], + advanced=True, + ), + IO.Combo.Input( + "thinking_level", + options=["MINIMAL", "HIGH"], + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + IO.Image.Output( + display_name="thought_image", + tooltip="First image from the model's thinking process. " + "Only available with thinking_level HIGH and IMAGE+TEXT modality.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=GEMINI_IMAGE_2_PRICE_BADGE, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + thinking_level: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "Nano Banana 2 (Gemini 3.1 Flash Image)": + model = "gemini-3.1-flash-image-preview" + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/vertexai/gemini/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=thinking_level), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput( + await get_image_from_response(response), + get_text_from_response(response), + await 
get_image_from_response(response, thought=True), + ) class GeminiExtension(ComfyExtension): @@ -753,6 +1013,7 @@ class GeminiExtension(ComfyExtension): GeminiNode, GeminiImage, GeminiImage2, + GeminiNanoBanana2, GeminiInputFiles, ] diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py new file mode 100644 index 000000000..dabc899d6 --- /dev/null +++ b/comfy_api_nodes/nodes_grok.py @@ -0,0 +1,728 @@ +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.grok import ( + ImageEditRequest, + ImageGenerationRequest, + ImageGenerationResponse, + InputUrlObject, + VideoEditRequest, + VideoExtensionRequest, + VideoGenerationRequest, + VideoGenerationResponse, + VideoStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_images_to_comfyapi, + upload_video_to_comfyapi, + validate_string, + validate_video_duration, +) + + +def _extract_grok_price(response) -> float | None: + if response.usage and response.usage.cost_in_usd_ticks is not None: + return response.usage.cost_in_usd_ticks / 10_000_000_000 + return None + + +def _extract_grok_video_price(response) -> float | None: + price = _extract_grok_price(response) + if price is not None: + return price * 1.43 + return None + + +class GrokImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageNode", + display_name="Grok Image", + category="api node/image/Grok", + description="Generate images using Grok based on a text prompt", + inputs=[ + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + ), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"], optional=True), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 
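/* per-image rate: 0.07 for "pro" models, 0.02 otherwise; multiplied by number_of_images */ 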
0.07 : 0.02; + {"type":"usd","usd": $rate * widgets.number_of_images} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + aspect_ratio: str, + number_of_images: int, + seed: int, + resolution: str = "1K", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/generations", method="POST"), + data=ImageGenerationRequest( + model=model, + prompt=prompt, + aspect_ratio=aspect_ratio, + n=number_of_images, + seed=seed, + resolution=resolution.lower(), + ), + response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokImageEditNode", + display_name="Grok Image Edit", + category="api node/image/Grok", + description="Modify an existing image based on a text prompt", + inputs=[ + IO.Combo.Input( + "model", + options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"], + ), + IO.Image.Input("image", display_name="images"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="The text prompt used to generate the image", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Int.Input( + "number_of_images", + default=1, + min=1, + max=10, + step=1, + tooltip="Number of edited images to generate", + display_mode=IO.NumberDisplay.number, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Combo.Input( + "aspect_ratio", + options=[ + "auto", + "1:1", + "2:3", + "3:2", + "3:4", + "4:3", + "9:16", + "16:9", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", + ], + optional=True, + tooltip="Only allowed when multiple images are connected to the image input.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]), + expr=""" + ( + $rate := $contains(widgets.model, "pro") ? 0.07 : 0.02; + {"type":"usd","usd": 0.002 + $rate * widgets.number_of_images} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + prompt: str, + resolution: str, + number_of_images: int, + seed: int, + aspect_ratio: str = "auto", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + if model == "grok-imagine-image-pro": + if get_number_of_images(image) > 1: + raise ValueError("The pro model supports only 1 input image.") + elif get_number_of_images(image) > 3: + raise ValueError("A maximum of 3 input images is supported.") + if aspect_ratio != "auto" and get_number_of_images(image) == 1: + raise ValueError( + "Custom aspect ratio is only allowed when multiple images are connected to the image input." 
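Note that the multi-image return path in the execute() above leans on torch.cat batching [1, H, W, C] tensors along dim 0, which quietly assumes every returned image shares one resolution. A minimal illustration:

import torch

a = torch.zeros(1, 512, 512, 3)  # one ComfyUI-style image tensor
b = torch.zeros(1, 512, 512, 3)
batch = torch.cat([a, b])  # stacks along dim 0 -> [2, 512, 512, 3]
assert batch.shape == (2, 512, 512, 3)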
+ ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"), + data=ImageEditRequest( + model=model, + images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image], + prompt=prompt, + resolution=resolution.lower(), + n=number_of_images, + seed=seed, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, + ), + response_model=ImageGenerationResponse, + price_extractor=_extract_grok_price, + ) + if len(response.data) == 1: + return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url)) + return IO.NodeOutput( + torch.cat( + [await download_url_to_image_tensor(i) for i in [str(d.url) for d in response.data if d.url]], + ) + ) + + +class GrokVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoNode", + display_name="Grok Video", + category="api node/video/Grok", + description="Generate video from a prompt or an image", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=1, + max=15, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Image.Input("image", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), + expr=""" + ( + $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $base := $rate * widgets.duration; + {"type":"usd","usd": inputs.image.connected ? 
$base + 0.002 : $base} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + image: Input.Image | None = None, + ) -> IO.NodeOutput: + if model == "grok-imagine-video-beta": + model = "grok-imagine-video" + image_url = None + if image is not None: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + image_url = InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}") + validate_string(prompt, strip_whitespace=True, min_length=1) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model, + image=image_url, + prompt=prompt, + resolution=resolution, + duration=duration, + aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoEditNode", + display_name="Grok Video Edit", + category="api node/video/Grok", + description="Edit an existing video based on a text prompt.", + inputs=[ + IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.Video.Input("video", tooltip="Maximum supported duration is 8.7 seconds and 50MB file size."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=1, max_duration=8.7) + video_stream = video.get_stream_source() + video_size = get_fs_object_size(video_stream) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/edits", method="POST"), + data=VideoEditRequest( + model=model, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + prompt=prompt, + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_price, + ) + return IO.NodeOutput(await 
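The Grok Video badge above reduces to a small function. This Python restatement of the JSONata is a readability aid, not an extra API; the 0.002 is the flat surcharge applied when an input image is connected:

def grok_video_price(resolution: str, duration: int, image_connected: bool) -> float:
    # Per-second rate by resolution, plus a flat 0.002 USD for an input image.
    rate = 0.07 if resolution == "720p" else 0.05
    base = rate * duration
    return base + 0.002 if image_connected else base

assert round(grok_video_price("720p", 6, False), 2) == 0.42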
download_url_to_video_output(response.video.url)) + + +class GrokVideoReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoReferenceNode", + display_name="Grok Reference-to-Video", + category="api node/video/Grok", + description="Generate video guided by reference images as style and content references.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of the desired video.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="reference_", + min=1, + max=7, + ), + tooltip="Up to 7 reference images to guide the video generation.", + ), + IO.Combo.Input( + "resolution", + options=["480p", "720p"], + tooltip="The resolution of the output video.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "4:3", "3:2", "1:1", "2:3", "3:4", "9:16"], + tooltip="The aspect ratio of the output video.", + ), + IO.Int.Input( + "duration", + default=6, + min=2, + max=10, + step=1, + tooltip="The duration of the output video in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model.duration", "model.resolution"], + input_groups=["model.reference_images"], + ), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $dur := $lookup(widgets, "model.duration"); + $refs := inputGroups["model.reference_images"]; + $rate := $res = "720p" ? 
0.07 : 0.05; + $price := ($rate * $dur + 0.002 * $refs) * 1.43; + {"type":"usd","usd": $price} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + ref_image_urls = await upload_images_to_comfyapi( + cls, + list(model["reference_images"].values()), + mime_type="image/png", + wait_label="Uploading base images", + max_images=7, + ) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), + data=VideoGenerationRequest( + model=model["model"], + reference_images=[InputUrlObject(url=i) for i in ref_image_urls], + prompt=prompt, + resolution=model["resolution"], + duration=model["duration"], + aspect_ratio=model["aspect_ratio"], + seed=seed, + ), + response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokVideoExtendNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GrokVideoExtendNode", + display_name="Grok Video Extend", + category="api node/video/Grok", + description="Extend an existing video with a seamless continuation based on a text prompt.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text description of what should happen next in the video.", + ), + IO.Video.Input("video", tooltip="Source video to extend. MP4 format, 2-15 seconds."), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "grok-imagine-video", + [ + IO.Int.Input( + "duration", + default=8, + min=2, + max=10, + step=1, + tooltip="Length of the extension in seconds.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + ], + tooltip="The model to use for video extension.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model.duration"]), + expr=""" + ( + $dur := $lookup(widgets, "model.duration"); + { + "type": "range_usd", + "min_usd": (0.02 + 0.05 * $dur) * 1.43, + "max_usd": (0.15 + 0.05 * $dur) * 1.43 + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + video: Input.Video, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + validate_video_duration(video, min_duration=2, max_duration=15) + video_size = get_fs_object_size(video.get_stream_source()) + if video_size > 50 * 1024 * 1024: + raise ValueError(f"Video size ({video_size / 1024 / 1024:.1f}MB) exceeds 50MB limit.") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/xai/v1/videos/extensions", method="POST"), + data=VideoExtensionRequest( + prompt=prompt, + video=InputUrlObject(url=await upload_video_to_comfyapi(cls, video)), + duration=model["duration"], + ), 
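Worth noting how the DynamicCombo and Autogrow inputs arrive in execute() above: model is a plain dict. The shape below is inferred from the lookups in this file and is illustrative, not a documented contract:

# Hypothetical resolved value of the `model` input (keys inferred from execute()):
model = {
    "model": "grok-imagine-video",
    "resolution": "720p",
    "aspect_ratio": "16:9",
    "duration": 6,
    # Autogrow entries carry the "reference_" prefix declared in the template:
    "reference_images": {"reference_1": "<image tensor>", "reference_2": "<image tensor>"},
}
reference_images = list(model["reference_images"].values())  # what gets uploaded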
+ response_model=VideoGenerationResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), + status_extractor=lambda r: r.status if r.status is not None else "complete", + response_model=VideoStatusResponse, + price_extractor=_extract_grok_video_price, + ) + return IO.NodeOutput(await download_url_to_video_output(response.video.url)) + + +class GrokExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GrokImageNode, + GrokImageEditNode, + GrokVideoNode, + GrokVideoReferenceNode, + GrokVideoEditNode, + GrokVideoExtendNode, + ] + + +async def comfy_entrypoint() -> GrokExtension: + return GrokExtension() diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py new file mode 100644 index 000000000..488080a74 --- /dev/null +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -0,0 +1,342 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.hitpaw import ( + ImageEnhanceTaskCreateRequest, + InputVideoModel, + TaskCreateDataResponse, + TaskCreateResponse, + TaskStatusPollRequest, + TaskStatusResponse, + VideoEnhanceTaskCreateRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + downscale_image_tensor, + get_image_dimensions, + poll_op, + sync_op, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_video_duration, +) + +VIDEO_MODELS_MODELS_MAP = { + "Portrait Restore Model (1x)": "portrait_restore_1x", + "Portrait Restore Model (2x)": "portrait_restore_2x", + "General Restore Model (1x)": "general_restore_1x", + "General Restore Model (2x)": "general_restore_2x", + "General Restore Model (4x)": "general_restore_4x", + "Ultra HD Model (2x)": "ultrahd_restore_2x", + "Generative Model (1x)": "generative_1x", +} + +# Resolution name to target dimension (shorter side) in pixels +RESOLUTION_TARGET_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2160, + "8K": 4320, +} + +# Square (1:1) resolutions use standard square dimensions +RESOLUTION_SQUARE_MAP = { + "720p": 720, + "1080p": 1080, + "2K/QHD": 1440, + "4K/UHD": 2048, # DCI 4K square + "8K": 4096, # DCI 8K square +} + +# Models with limited resolution support (no 8K) +LIMITED_RESOLUTION_MODELS = {"Generative Model (1x)"} + +# Resolution options for different model types +RESOLUTIONS_LIMITED = ["original", "720p", "1080p", "2K/QHD", "4K/UHD"] +RESOLUTIONS_FULL = ["original", "720p", "1080p", "2K/QHD", "4K/UHD", "8K"] + +# Maximum output resolution in pixels +MAX_PIXELS_GENERATIVE = 32_000_000 +MAX_MP_GENERATIVE = MAX_PIXELS_GENERATIVE // 1_000_000 + + +class HitPawGeneralImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="HitPawGeneralImageEnhance", + display_name="HitPaw General Image Enhance", + category="api node/image/HitPaw", + description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. 
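RESOLUTION_TARGET_MAP above pins the output's shorter side; the video node's execute() further down then scales the longer side proportionally. A worked example of that arithmetic:

# 1920x960 source at "1080p": shorter side -> 1080, longer side keeps the 2:1 ratio.
src_w, src_h = 1920, 960
target = 1080  # RESOLUTION_TARGET_MAP["1080p"]
out_h = target
out_w = int(target * (src_w / src_h))
assert (out_w, out_h) == (2160, 1080)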
" + f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", + inputs=[ + IO.Combo.Input("model", options=["generative_portrait", "generative"]), + IO.Image.Input("image"), + IO.Combo.Input("upscale_factor", options=[1, 2, 4]), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed the limit.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := { + "generative_portrait": {"min": 0.02, "max": 0.06}, + "generative": {"min": 0.05, "max": 0.15} + }; + $price := $lookup($prices, widgets.model); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + upscale_factor: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + height, width = get_image_dimensions(image) + requested_scale = upscale_factor + output_pixels = height * width * requested_scale * requested_scale + if output_pixels > MAX_PIXELS_GENERATIVE: + if auto_downscale: + input_pixels = width * height + scale = 1 + max_input_pixels = MAX_PIXELS_GENERATIVE + + for candidate in [4, 2, 1]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= MAX_PIXELS_GENERATIVE: + scale = candidate + max_input_pixels = None + break + # Check if we can downscale input by at most 2x to fit + downscale_ratio = math.sqrt(scale_output_pixels / MAX_PIXELS_GENERATIVE) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = MAX_PIXELS_GENERATIVE // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + upscale_factor = scale + else: + output_width = width * requested_scale + output_height = height * requested_scale + raise ValueError( + f"Output size ({output_width}x{output_height} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {MAX_PIXELS_GENERATIVE:,} pixels ({MAX_MP_GENERATIVE}MP). " + f"Enable auto_downscale or use a smaller input image or a lower upscale factor." 
+ ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/photo-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=ImageEnhanceTaskCreateRequest( + model_name=f"{model}_{upscale_factor}x", + img_url=await upload_image_to_comfyapi(cls, image, total_pixels=None), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + request_price = initial_res.data.consume_coins / 1000 + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskCreateDataResponse(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) + + +class HitPawVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + model_options = [] + for model_name in VIDEO_MODELS_MODELS_MAP: + if model_name in LIMITED_RESOLUTION_MODELS: + resolutions = RESOLUTIONS_LIMITED + else: + resolutions = RESOLUTIONS_FULL + model_options.append( + IO.DynamicCombo.Option( + model_name, + [IO.Combo.Input("resolution", options=resolutions)], + ) + ) + + return IO.Schema( + node_id="HitPawVideoEnhance", + display_name="HitPaw Video Enhance", + category="api node/video/HitPaw", + description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " + "Prices shown are per second of video.", + inputs=[ + IO.DynamicCombo.Input("model", options=model_options), + IO.Video.Input("video"), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]), + expr=""" + ( + $m := $lookup(widgets, "model"); + $res := $lookup(widgets, "model.resolution"); + $standard_model_prices := { + "original": {"min": 0.01, "max": 0.198}, + "720p": {"min": 0.01, "max": 0.06}, + "1080p": {"min": 0.015, "max": 0.09}, + "2k/qhd": {"min": 0.02, "max": 0.117}, + "4k/uhd": {"min": 0.025, "max": 0.152}, + "8k": {"min": 0.033, "max": 0.198} + }; + $ultra_hd_model_prices := { + "original": {"min": 0.015, "max": 0.264}, + "720p": {"min": 0.015, "max": 0.092}, + "1080p": {"min": 0.02, "max": 0.12}, + "2k/qhd": {"min": 0.026, "max": 0.156}, + "4k/uhd": {"min": 0.034, "max": 0.203}, + "8k": {"min": 0.044, "max": 0.264} + }; + $generative_model_prices := { + "original": {"min": 0.015, "max": 0.338}, + "720p": {"min": 0.008, "max": 0.090}, + "1080p": {"min": 0.05, "max": 0.15}, + "2k/qhd": {"min": 0.038, "max": 0.225}, + "4k/uhd": {"min": 0.056, "max": 0.338} + }; + $prices := $contains($m, "ultra hd") ? $ultra_hd_model_prices : + $contains($m, "generative") ? 
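One observation on the pricing line above: consume_coins is divided by 1000, which implies a rate of 1000 coins per USD. That rate is inferred from this PR, not from HitPaw documentation.

consume_coins = 150  # illustrative value from a task-creation response
request_price = consume_coins / 1000
assert request_price == 0.15  # 150 coins -> $0.15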
$generative_model_prices : + $standard_model_prices; + $price := $lookup($prices, $res); + { + "type": "range_usd", + "min_usd": $price.min, + "max_usd": $price.max, + "format": {"approximate": true, "suffix": "/second"} + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: InputVideoModel, + video: Input.Video, + ) -> IO.NodeOutput: + validate_video_duration(video, min_duration=0.5, max_duration=60 * 60) + resolution = model["resolution"] + src_width, src_height = video.get_dimensions() + + if resolution == "original": + output_width = src_width + output_height = src_height + else: + if src_width == src_height: + target_size = RESOLUTION_SQUARE_MAP[resolution] + if target_size < src_width: + raise ValueError( + f"Selected resolution {resolution} ({target_size}x{target_size}) is smaller than " + f"the input video ({src_width}x{src_height}). Please select a higher resolution or 'original'." + ) + output_width = target_size + output_height = target_size + else: + min_dimension = min(src_width, src_height) + target_size = RESOLUTION_TARGET_MAP[resolution] + if target_size < min_dimension: + raise ValueError( + f"Selected resolution {resolution} ({target_size}p) is smaller than " + f"the input video's shorter dimension ({min_dimension}p). " + f"Please select a higher resolution or 'original'." + ) + if src_width > src_height: + output_height = target_size + output_width = int(target_size * (src_width / src_height)) + else: + output_width = target_size + output_height = int(target_size * (src_height / src_width)) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/video-enhancer", method="POST"), + response_model=TaskCreateResponse, + data=VideoEnhanceTaskCreateRequest( + video_url=await upload_video_to_comfyapi(cls, video), + resolution=[output_width, output_height], + original_resolution=[src_width, src_height], + model_name=VIDEO_MODELS_MODELS_MAP[model["model"]], + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + if initial_res.code != 200: + raise ValueError(f"Task creation failed with code {initial_res.code}: {initial_res.message}") + request_price = initial_res.data.consume_coins / 1000 + final_response = await poll_op( + cls, + ApiEndpoint(path="/proxy/hitpaw/api/task-status", method="POST"), + data=TaskStatusPollRequest(job_id=initial_res.data.job_id), + response_model=TaskStatusResponse, + status_extractor=lambda x: x.data.status, + price_extractor=lambda x: request_price, + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) + + +class HitPawExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HitPawGeneralImageEnhance, + HitPawVideoEnhance, + ] + + +async def comfy_entrypoint() -> HitPawExtension: + return HitPawExtension() diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py new file mode 100644 index 000000000..753c09b6e --- /dev/null +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -0,0 +1,743 @@ +import zipfile +from io import BytesIO + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, Types +from comfy_api_nodes.apis.hunyuan3d import ( + Hunyuan3DViewImage, + InputGenerateType, + ResultFile3D, + SmartTopologyRequest, + TaskFile3DInput, + TextureEditTaskRequest, + To3DPartTaskRequest, + To3DProTaskCreateResponse, + To3DProTaskQueryRequest, + To3DProTaskRequest, + 
To3DProTaskResultResponse, + To3DUVTaskRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + download_url_to_bytesio, + download_url_to_file_3d, + download_url_to_image_tensor, + downscale_image_tensor_by_max_side, + poll_op, + sync_op, + upload_3d_model_to_comfyapi, + upload_image_to_comfyapi, + validate_image_dimensions, + validate_string, +) + + +def _is_tencent_rate_limited(status: int, body: object) -> bool: + return ( + status == 400 + and isinstance(body, dict) + and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", "")) + ) + + +class ObjZipResult: + __slots__ = ("obj", "texture", "metallic", "normal", "roughness") + + def __init__( + self, + obj: Types.File3D, + texture: Input.Image | None = None, + metallic: Input.Image | None = None, + normal: Input.Image | None = None, + roughness: Input.Image | None = None, + ): + self.obj = obj + self.texture = texture + self.metallic = metallic + self.normal = normal + self.roughness = roughness + + +async def download_and_extract_obj_zip(url: str) -> ObjZipResult: + """The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images. + + When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps + identified by their filename suffixes. + """ + data = BytesIO() + await download_url_to_bytesio(url, data) + data.seek(0) + if not zipfile.is_zipfile(data): + data.seek(0) + return ObjZipResult(obj=Types.File3D(source=data, file_format="obj")) + data.seek(0) + obj_bytes = None + textures: dict[str, Input.Image] = {} + with zipfile.ZipFile(data) as zf: + for name in zf.namelist(): + lower = name.lower() + if lower.endswith(".obj"): + obj_bytes = zf.read(name) + elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")): + stem = lower.rsplit(".", 1)[0] + tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB") + matched_key = "texture" + for suffix, key in { + "_metallic": "metallic", + "_normal": "normal", + "_roughness": "roughness", + }.items(): + if stem.endswith(suffix): + matched_key = key + break + textures[matched_key] = tensor + if obj_bytes is None: + raise ValueError("ZIP archive does not contain an OBJ file.") + return ObjZipResult( + obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"), + texture=textures.get("texture"), + metallic=textures.get("metallic"), + normal=textures.get("normal"), + roughness=textures.get("roughness"), + ) + + +def get_file_from_response( + response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True +) -> ResultFile3D | None: + for i in response_objs: + if i.Type.lower() == file_type.lower(): + return i + if raise_if_not_found: + raise ValueError(f"'{file_type}' file type is not found in the response.") + return None + + +class TencentTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentTextToModelNode", + display_name="Hunyuan3D: Text to Model", + category="api node/3d/Tencent", + essentials_category="3D", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + 
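The texture-classification rule inside download_and_extract_obj_zip above, restated standalone (same suffix table; file names are illustrative):

def classify(stem: str) -> str:
    # Stems ending in a known suffix map to PBR channels; everything else
    # is treated as the base color texture.
    suffixes = {"_metallic": "metallic", "_normal": "normal", "_roughness": "roughness"}
    for suffix, key in suffixes.items():
        if stem.endswith(suffix):
            return key
    return "texture"

assert classify("mesh_albedo") == "texture"
assert classify("mesh_normal") == "normal"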
IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), + IO.Image.Output(display_name="texture_image"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["generate_type", "generate_type.pbr", "face_count"]), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + face_count: int, + generate_type: InputGenerateType, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + Prompt=prompt, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=task_id), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + ), + obj_result.obj, + obj_result.texture, + ) + + +class TencentImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentImageToModelNode", + display_name="Hunyuan3D: Image(s) to Model", + category="api node/3d/Tencent", + essentials_category="3D", + inputs=[ + IO.Combo.Input( + "model", + options=["3.0", "3.1"], + tooltip="The LowPoly option is unavailable for the `3.1` model.", + ), + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Int.Input("face_count", default=500000, min=40000, max=1500000), + IO.DynamicCombo.Input( + "generate_type", + options=[ + 
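The text-to-model badge above is easier to audit as plain Python. This mirror assumes, as the JSONata comparison does, that widget values arrive lowercased:

def hunyuan_t2m_price(generate_type: str, pbr: bool, face_count: int) -> float:
    # Base credits by mode, +10 for PBR, +10 for a non-default face count,
    # at the 0.02 USD-per-credit rate implied by the expression.
    base = {"normal": 25, "lowpoly": 30}.get(generate_type, 15)
    extra = (10 if pbr else 0) + (10 if face_count != 500000 else 0)
    return (base + extra) * 0.02

assert round(hunyuan_t2m_price("normal", True, 500000), 2) == 0.7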
IO.DynamicCombo.Option("Normal", [IO.Boolean.Input("pbr", default=False)]), + IO.DynamicCombo.Option( + "LowPoly", + [ + IO.Combo.Input("polygon_type", options=["triangle", "quadrilateral"]), + IO.Boolean.Input("pbr", default=False), + ], + ), + IO.DynamicCombo.Option("Geometry", []), + ], + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), + IO.Image.Output(display_name="texture_image"), + IO.Image.Output(display_name="optional_metallic"), + IO.Image.Output(display_name="optional_normal"), + IO.Image.Output(display_name="optional_roughness"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["generate_type", "generate_type.pbr", "face_count"], + inputs=["image_left", "image_right", "image_back"], + ), + expr=""" + ( + $base := widgets.generate_type = "normal" ? 25 : widgets.generate_type = "lowpoly" ? 30 : 15; + $multiview := ( + inputs.image_left.connected or inputs.image_right.connected or inputs.image_back.connected + ) ? 10 : 0; + $pbr := $lookup(widgets, "generate_type.pbr") ? 10 : 0; + $face := widgets.face_count != 500000 ? 10 : 0; + {"type":"usd","usd": ($base + $multiview + $pbr + $face) * 0.02} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + face_count: int, + generate_type: InputGenerateType, + seed: int, + image_left: Input.Image | None = None, + image_right: Input.Image | None = None, + image_back: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + if model == "3.1" and generate_type["generate_type"].lower() == "lowpoly": + raise ValueError("The LowPoly option is currently unavailable for the 3.1 model.") + validate_image_dimensions(image, min_width=128, min_height=128) + multiview_images = [] + for k, v in { + "left": image_left, + "right": image_right, + "back": image_back, + }.items(): + if v is None: + continue + validate_image_dimensions(v, min_width=128, min_height=128) + multiview_images.append( + Hunyuan3DViewImage( + ViewType=k, + ViewImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(v, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + ) + ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DProTaskRequest( + Model=model, + FaceCount=face_count, + GenerateType=generate_type["generate_type"], + ImageUrl=await upload_image_to_comfyapi( + cls, + downscale_image_tensor_by_max_side(image, max_side=4900), + mime_type="image/webp", + total_pixels=24_010_000, + ), + MultiViewImages=multiview_images if multiview_images else None, + EnablePBR=generate_type.get("pbr", None), + PolygonType=generate_type.get("polygon_type", None), + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + task_id = response.JobId + result = await poll_op( + cls, + 
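A small numeric note on the upload parameters above: the total_pixels cap is exactly the squared max side, so the two limits describe the same bound.

assert 4900 * 4900 == 24_010_000  # max_side=4900 and total_pixels=24_010_000 agree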
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-pro/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=task_id), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d( + get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id + ), + obj_result.obj, + obj_result.texture, + obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), + obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), + obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), + ) + + +class TencentModelTo3DUVNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentModelTo3DUVNode", + display_name="Hunyuan3D: Model to UV", + category="api node/3d/Tencent", + description="Perform UV unfolding on a 3D model to generate UV texture. " + "Input model must have less than 30000 faces.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny], + tooltip="Input 3D model (GLB, OBJ, or FBX)", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj", "fbx"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " + f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." 
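For reference, a payload that _is_tencent_rate_limited (defined earlier in this file) would treat as a rate limit. The envelope shape follows what the predicate reads; the concrete Code value is a guess:

status = 400
body = {"Response": {"Error": {"Code": "RequestLimitExceeded"}}}  # hypothetical sample
assert "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))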
+ ) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DUVTaskRequest( + File=TaskFile3DInput( + Type=file_format.upper(), + Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format), + ) + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + ) + + +class Tencent3DTextureEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Tencent3DTextureEditNode", + display_name="Hunyuan3D: 3D Texture Edit", + category="api node/3d/Tencent", + description="After inputting the 3D model, perform 3D model texture redrawing.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DFBX, IO.File3DAny], + tooltip="3D model in FBX format. Model should have less than 100000 faces.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DOBJ.Output(display_name="OBJ"), + IO.Image.Output(display_name="texture_image"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.6}""", + ), + ) + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format != "fbx": + raise ValueError(f"Unsupported file format: '{file_format}'. 
Only FBX format is supported.") + validate_string(prompt, field_name="prompt", min_length=1, max_length=1024) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"), + response_model=To3DProTaskCreateResponse, + data=TextureEditTaskRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + Prompt=prompt, + EnablePBR=True, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"), + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url), + ) + + +class Tencent3DPartNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Tencent3DPartNode", + display_name="Hunyuan3D: 3D Part", + category="api node/3d/Tencent", + description="Automatically perform component identification and generation based on the model structure.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DFBX, IO.File3DAny], + tooltip="3D model in FBX format. Model should have less than 30000 faces.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'), + ) + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format != "fbx": + raise ValueError(f"Unsupported file format: '{file_format}'. 
Only FBX format is supported.") + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"), + response_model=To3DProTaskCreateResponse, + data=To3DPartTaskRequest( + File=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), + ) + + +class TencentSmartTopologyNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TencentSmartTopologyNode", + display_name="Hunyuan3D: Smart Topology", + category="api node/3d/Tencent", + description="Perform smart retopology on a 3D model. " + "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny], + tooltip="Input 3D model (GLB or OBJ)", + ), + IO.Combo.Input( + "polygon_type", + options=["triangle", "quadrilateral"], + tooltip="Surface composition type.", + ), + IO.Combo.Input( + "face_level", + options=["medium", "high", "low"], + tooltip="Polygon reduction level.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.File3DOBJ.Output(display_name="OBJ"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'), + ) + + SUPPORTED_FORMATS = {"glb", "obj"} + + @classmethod + async def execute( + cls, + model_3d: Types.File3D, + polygon_type: str, + face_level: str, + seed: int, + ) -> IO.NodeOutput: + _ = seed + file_format = model_3d.format.lower() + if file_format not in cls.SUPPORTED_FORMATS: + raise ValueError( + f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}." 
+ ) + model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"), + response_model=To3DProTaskCreateResponse, + data=SmartTopologyRequest( + File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url), + PolygonType=polygon_type, + FaceLevel=face_level, + ), + is_rate_limited=_is_tencent_rate_limited, + ) + if response.Error: + raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}") + result = await poll_op( + cls, + ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"), + data=To3DProTaskQueryRequest(JobId=response.JobId), + response_model=To3DProTaskResultResponse, + status_extractor=lambda r: r.Status, + ) + return IO.NodeOutput( + await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), + ) + + +class TencentHunyuan3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TencentTextToModelNode, + TencentImageToModelNode, + TencentModelTo3DUVNode, + Tencent3DTextureEditNode, + Tencent3DPartNode, + TencentSmartTopologyNode, + ] + + +async def comfy_entrypoint() -> TencentHunyuan3DExtension: + return TencentHunyuan3DExtension() diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 48f94e612..97c3609bd 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -4,7 +4,7 @@ from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.ideogram import ( IdeogramGenerateRequest, IdeogramGenerateResponse, ImageRequest, @@ -236,7 +236,6 @@ class IdeogramV1(IO.ComfyNode): display_name="Ideogram V1", category="api node/image/Ideogram", description="Generates images using the Ideogram V1 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -262,6 +261,7 @@ class IdeogramV1(IO.ComfyNode): default="AUTO", tooltip="Determine if MagicPrompt should be used in generation", optional=True, + advanced=True, ), IO.Int.Input( "seed", @@ -298,6 +298,17 @@ class IdeogramV1(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 0.0286 : 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -351,7 +362,6 @@ class IdeogramV2(IO.ComfyNode): display_name="Ideogram V2", category="api node/image/Ideogram", description="Generates images using the Ideogram V2 model.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -385,6 +395,7 @@ class IdeogramV2(IO.ComfyNode): default="AUTO", tooltip="Determine if MagicPrompt should be used in generation", optional=True, + advanced=True, ), IO.Int.Input( "seed", @@ -402,6 +413,7 @@ class IdeogramV2(IO.ComfyNode): default="NONE", tooltip="Style type for generation (V2 only)", optional=True, + advanced=True, ), IO.String.Input( "negative_prompt", @@ -436,6 +448,17 @@ class IdeogramV2(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["num_images", "turbo"]), + expr=""" + ( + $n := widgets.num_images; + $base := (widgets.turbo = true) ? 
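The Ideogram V1 badge above, restated in Python for auditability (same constants, same two-decimal rounding):

def ideogram_v1_price(num_images: int, turbo: bool) -> float:
    # Turbo renders bill at the lower per-image rate.
    base = 0.0286 if turbo else 0.0858
    return round(base * num_images, 2)

assert ideogram_v1_price(2, False) == 0.17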
0.0715 : 0.1144; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod @@ -506,7 +529,6 @@ class IdeogramV3(IO.ComfyNode): category="api node/image/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", - is_api_node=True, inputs=[ IO.String.Input( "prompt", @@ -545,6 +567,7 @@ class IdeogramV3(IO.ComfyNode): default="AUTO", tooltip="Determine if MagicPrompt should be used in generation", optional=True, + advanced=True, ), IO.Int.Input( "seed", @@ -571,6 +594,7 @@ class IdeogramV3(IO.ComfyNode): default="DEFAULT", tooltip="Controls the trade-off between generation speed and quality", optional=True, + advanced=True, ), IO.Image.Input( "character_image", @@ -591,6 +615,23 @@ class IdeogramV3(IO.ComfyNode): IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed", "num_images"], inputs=["character_image"]), + expr=""" + ( + $n := widgets.num_images; + $speed := widgets.rendering_speed; + $hasChar := inputs.character_image.connected; + $base := + $contains($speed,"quality") ? ($hasChar ? 0.286 : 0.1287) : + $contains($speed,"default") ? ($hasChar ? 0.2145 : 0.0858) : + $contains($speed,"turbo") ? ($hasChar ? 0.143 : 0.0429) : + 0.0858; + {"type":"usd","usd": $round($base * $n, 2)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 1a6364fa0..9a37ccc53 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -38,7 +38,6 @@ from comfy_api_nodes.apis import ( KlingImageGenerationsRequest, KlingImageGenerationsResponse, KlingImageGenImageReferenceType, - KlingImageGenModelName, KlingImageGenAspectRatio, KlingVideoEffectsRequest, KlingVideoEffectsResponse, @@ -49,8 +48,11 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.kling_api import ( +from comfy_api_nodes.apis.kling import ( ImageToVideoWithAudioRequest, + KlingAvatarRequest, + MotionControlRequest, + MultiPromptEntry, OmniImageParamImage, OmniParamImage, OmniParamVideo, @@ -70,8 +72,10 @@ from comfy_api_nodes.util import ( sync_op, tensor_to_base64_string, upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, + validate_audio_duration, validate_image_aspect_ratio, validate_image_dimensions, validate_string, @@ -79,6 +83,31 @@ from comfy_api_nodes.util import ( validate_video_duration, ) + +def _generate_storyboard_inputs(count: int) -> list: + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"storyboard_{i}_prompt", + multiline=True, + default="", + tooltip=f"Prompt for storyboard segment {i}. 
Max 512 characters.", + ), + IO.Int.Input( + f"storyboard_{i}_duration", + default=4, + min=1, + max=15, + display_mode=IO.NumberDisplay.slider, + tooltip=f"Duration for storyboard segment {i} in seconds.", + ), + ] + ) + return inputs + + KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video" @@ -248,7 +277,6 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), - max_poll_attempts=160, ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -566,7 +594,7 @@ async def execute_lipsync( # Upload the audio file to Comfy API and get download URL if audio: audio_url = await upload_audio_to_comfyapi( - cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" ) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: @@ -763,6 +791,33 @@ class KlingTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -793,19 +848,48 @@ class OmniProTextToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProTextToVideoNode", - display_name="Kling Omni Text to Video (Pro)", + display_name="Kling 3.0 Omni Text to Video", category="api node/video/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. 
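The storyboard helper above declares one prompt/duration pair per segment, so the widget names it produces are fully predictable; for example:

# _generate_storyboard_inputs(2) declares inputs named, in order:
names = [
    "storyboard_1_prompt",
    "storyboard_1_duration",
    "storyboard_2_prompt",
    "storyboard_2_duration",
]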
" + "Ignored when storyboards are enabled.", ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), - IO.Combo.Input("duration", options=[5, 10]), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Ignored for o1 model.", + optional=True, + ), + IO.Boolean.Input("generate_audio", default=False, optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -816,6 +900,20 @@ class OmniProTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? 
{"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -825,8 +923,46 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt: str, aspect_ratio: str, duration: int, + resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: - validate_string(prompt, min_length=1, max_length=2500) + _ = seed + if model_name == "kling-video-o1": + if duration not in (5, 10): + raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." + ) + response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -836,6 +972,11 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), + mode="pro" if resolution == "1080p" else "std", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, + sound="on" if generate_audio else "off", ), ) return await finish_omni_video_task(cls, response) @@ -847,30 +988,65 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", - display_name="Kling Omni First-Last-Frame to Video (Pro)", + display_name="Kling 3.0 Omni First-Last-Frame to Video", category="api node/video/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. " + "Ignored when storyboards are enabled.", ), - IO.Combo.Input("duration", options=["5", "10"]), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Image.Input("first_frame"), IO.Image.Input( "end_frame", optional=True, tooltip="An optional end frame for the video. " - "This cannot be used simultaneously with 'reference_images'.", + "This cannot be used simultaneously with 'reference_images'. 
" + "Does not work with storyboards.", ), IO.Image.Input( "reference_images", optional=True, tooltip="Up to 6 additional reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Only supported for kling-v3-omni.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="Generate audio for the video. Only supported for kling-v3-omni.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -881,6 +1057,20 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? {"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -892,11 +1082,60 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): first_frame: Input.Image, end_frame: Input.Image | None = None, reference_images: Input.Image | None = None, + resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-video-o1": + if duration > 10: + raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") prompt = normalize_omni_prompt_references(prompt) - validate_string(prompt, min_length=1, max_length=2500) + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) if end_frame is not None and reference_images is not None: raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + if end_frame is not None and stories_enabled: + raise ValueError("The 'end_frame' input cannot be used simultaneously with storyboards.") + if ( + model_name == "kling-video-o1" + and duration not in (5, 10) + and end_frame is None + and reference_images is None + ): + raise ValueError( + "Duration is only supported for 5 or 10 seconds if there is no end frame or reference images." 
+ ) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." + ) + validate_image_dimensions(first_frame, min_width=300, min_height=300) validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) image_list: list[OmniParamImage] = [ @@ -931,6 +1170,11 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, ), ) return await finish_omni_video_task(cls, response) @@ -942,23 +1186,57 @@ class OmniProImageToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProImageToVideoNode", - display_name="Kling Omni Image to Video (Pro)", + display_name="Kling 3.0 Omni Image to Video", category="api node/video/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the video content. " - "This can include both positive and negative descriptions.", + "This can include both positive and negative descriptions. " + "Ignored when storyboards are enabled.", ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), - IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Image.Input( "reference_images", tooltip="Up to 7 reference images.", ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.DynamicCombo.Input( + "storyboards", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations. " + "Only supported for kling-v3-omni.", + optional=True, + ), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="Generate audio for the video. 
Only supported for kling-v3-omni.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -969,6 +1247,20 @@ class OmniProImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $isV3 := $contains(widgets.model_name, "v3"); + $audio := $isV3 and widgets.generate_audio; + $rates := $audio + ? {"std": 0.112, "pro": 0.14} + : {"std": 0.084, "pro": 0.112}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -979,9 +1271,47 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio: str, duration: int, reference_images: Input.Image, + resolution: str = "1080p", + storyboards: dict | None = None, + generate_audio: bool = False, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-video-o1": + if duration > 10: + raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") + if generate_audio: + raise ValueError("kling-video-o1 does not support audio generation.") + stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" + if stories_enabled and model_name == "kling-video-o1": + raise ValueError("kling-video-o1 does not support storyboards.") prompt = normalize_omni_prompt_references(prompt) - validate_string(prompt, min_length=1, max_length=2500) + validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500) + + multi_shot = None + multi_prompt_list = None + if stories_enabled: + count = int(storyboards["storyboards"].split()[0]) + multi_shot = True + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = storyboards[f"storyboard_{i}_prompt"] + sb_duration = storyboards[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list) + if total_storyboard_duration != duration: + raise ValueError( + f"Total storyboard duration ({total_storyboard_duration}s) " + f"must equal the global duration ({duration}s)." 
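+ # same rule as the text-to-video node: segment durations must sum exactly to the requested total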
+ ) + if get_number_of_images(reference_images) > 7: raise ValueError("The maximum number of reference images is 7.") for i in reference_images: @@ -1000,6 +1330,11 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, + mode="pro" if resolution == "1080p" else "std", + sound="on" if generate_audio else "off", + multi_shot=multi_shot, + multi_prompt=multi_prompt_list, + shot_type="customize" if multi_shot else None, ), ) return await finish_omni_video_task(cls, response) @@ -1011,11 +1346,11 @@ class OmniProVideoToVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProVideoToVideoNode", - display_name="Kling Omni Video to Video (Pro)", + display_name="Kling 3.0 Omni Video to Video", category="api node/video/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, @@ -1031,6 +1366,18 @@ class OmniProVideoToVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -1041,6 +1388,16 @@ class OmniProVideoToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? 
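+ /* 720p is billed at the "std" rate, 1080p at "pro" */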
"std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} + ) + """, + ), ) @classmethod @@ -1053,7 +1410,10 @@ class OmniProVideoToVideoNode(IO.ComfyNode): reference_video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", + seed: int = 0, ) -> IO.NodeOutput: + _ = seed prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) @@ -1085,6 +1445,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): duration=str(duration), image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1096,11 +1457,12 @@ class OmniProEditVideoNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProEditVideoNode", - display_name="Kling Omni Edit Video (Pro)", + display_name="Kling 3.0 Omni Edit Video", category="api node/video/Kling", + essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ - IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.String.Input( "prompt", multiline=True, @@ -1114,6 +1476,18 @@ class OmniProEditVideoNode(IO.ComfyNode): tooltip="Up to 4 additional reference images.", optional=True, ), + IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Video.Output(), @@ -1124,6 +1498,16 @@ class OmniProEditVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution"]), + expr=""" + ( + $mode := (widgets.resolution = "720p") ? 
"std" : "pro"; + $rates := {"std": 0.126, "pro": 0.168}; + {"type":"usd","usd": $lookup($rates, $mode), "format":{"suffix":"/second"}} + ) + """, + ), ) @classmethod @@ -1134,7 +1518,10 @@ class OmniProEditVideoNode(IO.ComfyNode): video: Input.Video, keep_original_sound: bool, reference_images: Input.Image | None = None, + resolution: str = "1080p", + seed: int = 0, ) -> IO.NodeOutput: + _ = seed prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) validate_video_duration(video, min_duration=3.0, max_duration=10.05) @@ -1166,6 +1553,7 @@ class OmniProEditVideoNode(IO.ComfyNode): duration=None, image_list=image_list if image_list else None, video_list=video_list, + mode="pro" if resolution == "1080p" else "std", ), ) return await finish_omni_video_task(cls, response) @@ -1177,27 +1565,43 @@ class OmniProImageNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingOmniProImageNode", - display_name="Kling Omni Image (Pro)", + display_name="Kling 3.0 Omni Image", category="api node/image/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ - IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), IO.String.Input( "prompt", multiline=True, tooltip="A text prompt describing the image content. " "This can include both positive and negative descriptions.", ), - IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input("resolution", options=["1K", "2K", "4K"]), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], ), + IO.Combo.Input( + "series_amount", + options=["disabled", "2", "3", "4", "5", "6", "7", "8", "9"], + tooltip="Generate a series of images. Not supported for kling-image-o1.", + ), IO.Image.Input( "reference_images", tooltip="Up to 10 additional reference images.", optional=True, ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -1208,6 +1612,18 @@ class OmniProImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]), + expr=""" + ( + $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056}; + $base := $lookup($prices, widgets.resolution); + $isO1 := widgets.model_name = "kling-image-o1"; + $mult := ($isO1 or widgets.series_amount = "disabled") ? 
1 : $number(widgets.series_amount); + {"type":"usd","usd": $base * $mult} + ) + """, + ), ) @classmethod @@ -1217,8 +1633,13 @@ class OmniProImageNode(IO.ComfyNode): prompt: str, resolution: str, aspect_ratio: str, + series_amount: str = "disabled", reference_images: Input.Image | None = None, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed + if model_name == "kling-image-o1" and resolution == "4K": + raise ValueError("4K resolution is not supported for kling-image-o1 model.") prompt = normalize_omni_prompt_references(prompt) validate_string(prompt, min_length=1, max_length=2500) image_list: list[OmniImageParamImage] = [] @@ -1230,6 +1651,9 @@ class OmniProImageNode(IO.ComfyNode): validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): image_list.append(OmniImageParamImage(image=i)) + use_series = series_amount != "disabled" + if use_series and model_name == "kling-image-o1": + raise ValueError("kling-image-o1 does not support series generation.") response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), @@ -1240,6 +1664,8 @@ class OmniProImageNode(IO.ComfyNode): resolution=resolution.lower(), aspect_ratio=aspect_ratio, image_list=image_list if image_list else None, + result_type="series" if use_series else None, + series_amount=int(series_amount) if use_series else None, ), ) if response.code: @@ -1252,7 +1678,9 @@ class OmniProImageNode(IO.ComfyNode): response_model=TaskStatusResponse, status_extractor=lambda r: (r.data.task_status if r.data else None), ) - return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + images = final_response.data.task_result.series_images or final_response.data.task_result.images + tensors = [await download_url_to_image_tensor(img.url) for img in images] + return IO.NodeOutput(torch.cat(tensors, dim=0)) class KlingCameraControlT2VNode(IO.ComfyNode): @@ -1293,6 +1721,9 @@ class KlingCameraControlT2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14}""", + ), ) @classmethod @@ -1355,6 +1786,33 @@ class KlingImage2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + $contains($model,"v2-5-turbo") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : ($contains($model,"v2-1-master") or $contains($model,"v2-master")) + ? ($contains($dur,"10") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : ($contains($model,"v2-1") or $contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? 
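+ /* v1 standard mode: 10s vs 5s clip */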
{"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1428,6 +1886,9 @@ class KlingCameraControlI2VNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.49}""", + ), ) @classmethod @@ -1498,6 +1959,33 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $m := widgets.mode; + $contains($m,"v2-5-turbo") + ? ($contains($m,"10") ? {"type":"usd","usd":0.7} : {"type":"usd","usd":0.35}) + : $contains($m,"v2-1") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : $contains($m,"v2-master") + ? ($contains($m,"10s") ? {"type":"usd","usd":2.8} : {"type":"usd","usd":1.4}) + : $contains($m,"v1-6") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($m,"v1") + ? ( + $contains($m,"pro") + ? ($contains($m,"10s") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($m,"10s") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1563,6 +2051,9 @@ class KlingVideoExtendNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.28}""", + ), ) @classmethod @@ -1644,6 +2135,29 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "model_name", "duration"]), + expr=""" + ( + $mode := widgets.mode; + $model := widgets.model_name; + $dur := widgets.duration; + ($contains($model,"v1-6") or $contains($model,"v1-5")) + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.56} : {"type":"usd","usd":0.28}) + ) + : $contains($model,"v1") + ? ( + $contains($mode,"pro") + ? ($contains($dur,"10") ? {"type":"usd","usd":0.98} : {"type":"usd","usd":0.49}) + : ($contains($dur,"10") ? {"type":"usd","usd":0.28} : {"type":"usd","usd":0.14}) + ) + : {"type":"usd","usd":0.14} + ) + """, + ), ) @classmethod @@ -1708,6 +2222,16 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["effect_scene"]), + expr=""" + ( + ($contains(widgets.effect_scene,"dizzydizzy") or $contains(widgets.effect_scene,"bloombloom")) + ? {"type":"usd","usd":0.49} + : {"type":"usd","usd":0.28} + ) + """, + ), ) @classmethod @@ -1741,6 +2265,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): node_id="KlingLipSyncAudioToVideoNode", display_name="Kling Lip Sync Video with Audio", category="api node/video/Kling", + essentials_category="Video Generation", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. 
The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ IO.Video.Input("video"), @@ -1762,6 +2287,9 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1809,6 +2337,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): max=2.0, display_mode=IO.NumberDisplay.slider, tooltip="Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.", + advanced=True, ), ], outputs=[ @@ -1822,6 +2351,9 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1,"format":{"approximate":true}}""", + ), ) @classmethod @@ -1872,6 +2404,9 @@ class KlingVirtualTryOnNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.7}""", + ), ) @classmethod @@ -1915,7 +2450,7 @@ class KlingImageGenerationNode(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingImageGenerationNode", - display_name="Kling Image Generation", + display_name="Kling 3.0 Image", category="api node/image/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ @@ -1924,6 +2459,7 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Combo.Input( "image_type", options=[i.value for i in KlingImageGenImageReferenceType], + advanced=True, ), IO.Float.Input( "image_fidelity", @@ -1933,6 +2469,7 @@ class KlingImageGenerationNode(IO.ComfyNode): step=0.01, display_mode=IO.NumberDisplay.slider, tooltip="Reference intensity for user-uploaded images", + advanced=True, ), IO.Float.Input( "human_fidelity", @@ -1942,12 +2479,9 @@ class KlingImageGenerationNode(IO.ComfyNode): step=0.01, display_mode=IO.NumberDisplay.slider, tooltip="Subject reference similarity", + advanced=True, ), - IO.Combo.Input( - "model_name", - options=[i.value for i in KlingImageGenModelName], - default="kling-v2", - ), + IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]), IO.Combo.Input( "aspect_ratio", options=[i.value for i in KlingImageGenAspectRatio], @@ -1961,6 +2495,17 @@ class KlingImageGenerationNode(IO.ComfyNode): tooltip="Number of generated images", ), IO.Image.Input("image", optional=True), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + optional=True, + ), ], outputs=[ IO.Image.Output(), @@ -1971,12 +2516,25 @@ class KlingImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model_name", "n"], inputs=["image"]), + expr=""" + ( + $m := widgets.model_name; + $base := + $contains($m,"kling-v1-5") + ? (inputs.image.connected ? 0.028 : 0.014) + : $contains($m,"kling-v3") ? 
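+ /* kling-v3 base price per image */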
0.028 : 0.014; + {"type":"usd","usd": $base * widgets.n} + ) + """, + ), ) @classmethod async def execute( cls, - model_name: KlingImageGenModelName, + model_name: str, prompt: str, negative_prompt: str, image_type: KlingImageGenImageReferenceType, @@ -1985,17 +2543,11 @@ class KlingImageGenerationNode(IO.ComfyNode): n: int, aspect_ratio: KlingImageGenAspectRatio, image: torch.Tensor | None = None, + seed: int = 0, ) -> IO.NodeOutput: + _ = seed validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) - - if image is None: - image_type = None - elif model_name == KlingImageGenModelName.kling_v1: - raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.") - else: - image = tensor_to_base64_string(image) - task_creation_response = await sync_op( cls, ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), @@ -2004,8 +2556,8 @@ class KlingImageGenerationNode(IO.ComfyNode): model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, - image=image, - image_reference=image_type, + image=tensor_to_base64_string(image) if image is not None else None, + image_reference=image_type if image is not None else None, image_fidelity=image_fidelity, human_fidelity=human_fidelity, n=n, @@ -2035,7 +2587,7 @@ class TextToVideoWithAudio(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingTextToVideoWithAudio", - display_name="Kling Text to Video with Audio", + display_name="Kling 2.6 Text to Video with Audio", category="api node/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), @@ -2043,7 +2595,7 @@ class TextToVideoWithAudio(IO.ComfyNode): IO.Combo.Input("mode", options=["pro"]), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("duration", options=[5, 10]), - IO.Boolean.Input("generate_audio", default=True), + IO.Boolean.Input("generate_audio", default=True, advanced=True), ], outputs=[ IO.Video.Output(), @@ -2054,6 +2606,10 @@ class TextToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 2 : 1)}""", + ), ) @classmethod @@ -2099,7 +2655,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="KlingImageToVideoWithAudio", - display_name="Kling Image(First Frame) to Video with Audio", + display_name="Kling 2.6 Image(First Frame) to Video with Audio", category="api node/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), @@ -2107,7 +2663,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), IO.Combo.Input("mode", options=["pro"]), IO.Combo.Input("duration", options=[5, 10]), - IO.Boolean.Input("generate_audio", default=True), + IO.Boolean.Input("generate_audio", default=True, advanced=True), ], outputs=[ IO.Video.Output(), @@ -2118,6 +2674,10 @@ class ImageToVideoWithAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "generate_audio"]), + expr="""{"type":"usd","usd": 0.07 * widgets.duration * (widgets.generate_audio ? 
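+ /* generating audio doubles the per-second rate */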
2 : 1)}""", + ), ) @classmethod @@ -2159,6 +2719,529 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) +class MotionControl(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingMotionControl", + display_name="Kling Motion Control", + category="api node/video/Kling", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.Image.Input("reference_image"), + IO.Video.Input( + "reference_video", + tooltip="Motion reference video used to drive movement/expression.\n" + "Duration limits depend on character_orientation:\n" + " - image: 3–10s (max 10s)\n" + " - video: 3–30s (max 30s)", + ), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Combo.Input( + "character_orientation", + options=["video", "image"], + tooltip="Controls where the character's facing/orientation comes from.\n" + "video: movements, expressions, camera moves, and orientation " + "follow the motion reference video (other details via prompt).\n" + "image: movements and expressions still follow the motion reference video, " + "but the character orientation matches the reference image (camera/other details via prompt).", + ), + IO.Combo.Input("mode", options=["pro", "std"]), + IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $prices := {"std": 0.07, "pro": 0.112}; + {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + reference_image: Input.Image, + reference_video: Input.Video, + keep_original_sound: bool, + character_orientation: str, + mode: str, + model: str = "kling-v2-6", + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2500) + validate_image_dimensions(reference_image, min_width=340, min_height=340) + validate_image_aspect_ratio(reference_image, (1, 2.5), (2.5, 1)) + if character_orientation == "image": + validate_video_duration(reference_video, min_duration=3, max_duration=10) + else: + validate_video_duration(reference_video, min_duration=3, max_duration=30) + validate_video_dimensions(reference_video, min_width=340, min_height=340, max_width=3850, max_height=3850) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/motion-control", method="POST"), + response_model=TaskStatusResponse, + data=MotionControlRequest( + prompt=prompt, + image_url=(await upload_images_to_comfyapi(cls, reference_image))[0], + video_url=await upload_video_to_comfyapi(cls, reference_video), + keep_original_sound="yes" if keep_original_sound else "no", + character_orientation=character_orientation, + mode=mode, + model_name=model, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/motion-control/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class KlingVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingVideoNode", + display_name="Kling 3.0 Video", + category="api node/video/Kling", + description="Generate videos with Kling V3. " + "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", + inputs=[ + IO.DynamicCombo.Input( + "multi_shot", + options=[ + IO.DynamicCombo.Option( + "disabled", + [ + IO.String.Input("prompt", multiline=True, default=""), + IO.String.Input("negative_prompt", multiline=True, default=""), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)), + IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)), + IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)), + IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)), + IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)), + IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)), + ], + tooltip="Generate a series of video segments with individual prompts and durations.", + ), + IO.Boolean.Input("generate_audio", default=True), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "kling-v3", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1"], + tooltip="Ignored in image-to-video mode.", + ), + ], + ), + ], + tooltip="Model and generation settings.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + IO.Image.Input( + "start_frame", + optional=True, + tooltip="Optional start frame image. When connected, switches to image-to-video mode.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model.resolution", + "generate_audio", + "multi_shot", + "multi_shot.duration", + "multi_shot.storyboard_1_duration", + "multi_shot.storyboard_2_duration", + "multi_shot.storyboard_3_duration", + "multi_shot.storyboard_4_duration", + "multi_shot.storyboard_5_duration", + "multi_shot.storyboard_6_duration", + ], + ), + expr=""" + ( + $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $res := $lookup(widgets, "model.resolution"); + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + $ms := widgets.multi_shot; + $isSb := $ms != "disabled"; + $n := $isSb ? $number($substring($ms, 0, 1)) : 0; + $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration"); + $d2 := $n >= 2 ? 
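+ /* only count durations for the segments the selected option exposes */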
$lookup(widgets, "multi_shot.storyboard_2_duration") : 0; + $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0; + $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0; + $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0; + $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0; + $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration"); + {"type":"usd","usd": $rate * $dur} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + multi_shot: dict, + generate_audio: bool, + model: dict, + seed: int, + start_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + _ = seed + mode = "pro" if model["resolution"] == "1080p" else "std" + custom_multi_shot = False + if multi_shot["multi_shot"] == "disabled": + shot_type = None + else: + shot_type = "customize" + custom_multi_shot = True + + multi_prompt_list = None + if shot_type == "customize": + count = int(multi_shot["multi_shot"].split()[0]) + multi_prompt_list = [] + for i in range(1, count + 1): + sb_prompt = multi_shot[f"storyboard_{i}_prompt"] + sb_duration = multi_shot[f"storyboard_{i}_duration"] + validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512) + multi_prompt_list.append( + MultiPromptEntry( + index=i, + prompt=sb_prompt, + duration=str(sb_duration), + ) + ) + duration = sum(int(e.duration) for e in multi_prompt_list) + if duration < 3 or duration > 15: + raise ValueError( + f"Total storyboard duration ({duration}s) must be between 3 and 15 seconds." + ) + else: + duration = multi_shot["duration"] + validate_string(multi_shot["prompt"], min_length=1, max_length=2500) + + if start_frame is not None: + validate_image_dimensions(start_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1)) + image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ), + ) + poll_path = f"/proxy/kling/v1/videos/image2video/{response.data.task_id}" + else: + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"), + response_model=TaskStatusResponse, + data=TextToVideoWithAudioRequest( + model_name=model["model"], + aspect_ratio=model["aspect_ratio"], + prompt=None if custom_multi_shot else multi_shot["prompt"], + negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"], + mode=mode, + duration=str(duration), + sound="on" if generate_audio else "off", + multi_shot=True if shot_type else None, + multi_prompt=multi_prompt_list, + shot_type=shot_type, + ), + ) + poll_path = f"/proxy/kling/v1/videos/text2video/{response.data.task_id}" + + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=poll_path), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class KlingFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingFirstLastFrameNode", + display_name="Kling 3.0 First-Last-Frame to Video", + category="api node/video/Kling", + description="Generate videos with Kling V3 using first and last frames.", + inputs=[ + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + display_mode=IO.NumberDisplay.slider, + ), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.Boolean.Input("generate_audio", default=True), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "kling-v3", + [ + IO.Combo.Input("resolution", options=["1080p", "720p"]), + ], + ), + ], + tooltip="Model and generation settings.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model.resolution", "generate_audio", "duration"], + ), + expr=""" + ( + $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $res := $lookup(widgets, "model.resolution"); + $audio := widgets.generate_audio ? "on" : "off"; + $rate := $lookup($lookup($rates, $res), $audio); + {"type":"usd","usd": $rate * widgets.duration} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image, + generate_audio: bool, + model: dict, + seed: int, + ) -> IO.NodeOutput: + _ = seed + validate_string(prompt, min_length=1, max_length=2500) + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") + image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), + response_model=TaskStatusResponse, + data=ImageToVideoWithAudioRequest( + model_name=model["model"], + image=image_url, + image_tail=image_tail_url, + prompt=prompt, + mode="pro" if model["resolution"] == "1080p" else "std", + duration=str(duration), + sound="on" if generate_audio else "off", + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + +class KlingAvatarNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingAvatarNode", + display_name="Kling Avatar 2.0", + category="api node/video/Kling", + description="Generate broadcast-style digital human videos from a single photo and an audio file.", + inputs=[ + IO.Image.Input( + "image", + tooltip="Avatar reference image. " + "Width and height must be at least 300px. Aspect ratio must be between 1:2.5 and 2.5:1.", + ), + IO.Audio.Input( + "sound_file", + tooltip="Audio input. Must be between 2 and 300 seconds in duration.", + ), + IO.Combo.Input("mode", options=["std", "pro"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + optional=True, + tooltip="Optional prompt to define avatar actions, emotions, and camera movements.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $prices := {"std": 0.056, "pro": 0.112}; + {"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + sound_file: Input.Audio, + mode: str, + seed: int, + prompt: str = "", + ) -> IO.NodeOutput: + validate_image_dimensions(image, min_width=300, min_height=300) + validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) + validate_audio_duration(sound_file, min_duration=2, max_duration=300) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/avatar/image2video", method="POST"), + response_model=TaskStatusResponse, + data=KlingAvatarRequest( + image=await upload_image_to_comfyapi(cls, image), + sound_file=await upload_audio_to_comfyapi( + cls, sound_file, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg" + ), + prompt=prompt or None, + mode=mode, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/avatar/image2video/{response.data.task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=800, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) + + class KlingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -2184,6 +3267,10 @@ class KlingExtension(ComfyExtension): OmniProImageNode, TextToVideoWithAudio, ImageToVideoWithAudio, + MotionControl, + KlingVideoNode, + KlingFirstLastFrameNode, + KlingAvatarNode, ] diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 7e61560dc..0a219af96 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -28,6 +28,22 @@ class ExecuteTaskRequest(BaseModel): image_uri: str | None = Field(None) +PRICE_BADGE = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $prices := { + "ltx-2 (pro)": {"1920x1080":0.06,"2560x1440":0.12,"3840x2160":0.24}, + "ltx-2 (fast)": {"1920x1080":0.04,"2560x1440":0.08,"3840x2160":0.16} + }; + $modelPrices := $lookup($prices, $lowercase(widgets.model)); + $pps := $lookup($modelPrices, widgets.resolution); + {"type":"usd","usd": $pps * widgets.duration} + ) + """, +) + + class TextToVideoNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -58,6 +74,7 @@ class TextToVideoNode(IO.ComfyNode): default=False, optional=True, tooltip="When true, the generated video will include AI-generated audio matching the scene.", + advanced=True, ), ], outputs=[ @@ -69,6 +86,7 @@ class TextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod @@ -134,6 +152,7 @@ class ImageToVideoNode(IO.ComfyNode): default=False, optional=True, tooltip="When true, the generated video will include AI-generated audio matching the scene.", + advanced=True, ), ], outputs=[ @@ -145,6 +164,7 @@ class ImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE, ) @classmethod diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 894f2b08c..9ed6cd299 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.luma_api import ( +from comfy_api_nodes.apis.luma import ( LumaAspectRatio, LumaCharacterRef, LumaConceptChain, @@ -189,6 +189,19 @@ class LumaImageGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? {"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -303,6 +316,19 @@ class LumaImageModifyNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m,"photon-flash-1") + ? {"type":"usd","usd":0.0027} + : $contains($m,"photon-1") + ? 
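+ /* plain photon-1 */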
{"type":"usd","usd":0.0104} + : {"type":"usd","usd":0.0246} + ) + """, + ), ) @classmethod @@ -395,6 +421,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -505,6 +532,8 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, + ) @classmethod @@ -568,6 +597,53 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return LumaKeyframes(frame0=frame0, frame1=frame1) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "resolution", "duration"]), + expr=""" + ( + $p := { + "ray-flash-2": { + "5s": {"4k":3.13,"1080p":0.79,"720p":0.34,"540p":0.2}, + "9s": {"4k":5.65,"1080p":1.42,"720p":0.61,"540p":0.36} + }, + "ray-2": { + "5s": {"4k":9.11,"1080p":2.27,"720p":1.02,"540p":0.57}, + "9s": {"4k":16.4,"1080p":4.1,"720p":1.83,"540p":1.03} + } + }; + + $m := widgets.model; + $d := widgets.duration; + $r := widgets.resolution; + + $modelKey := + $contains($m,"ray-flash-2") ? "ray-flash-2" : + $contains($m,"ray-2") ? "ray-2" : + $contains($m,"ray-1-6") ? "ray-1-6" : + "other"; + + $durKey := $contains($d,"5s") ? "5s" : $contains($d,"9s") ? "9s" : ""; + $resKey := + $contains($r,"4k") ? "4k" : + $contains($r,"1080p") ? "1080p" : + $contains($r,"720p") ? "720p" : + $contains($r,"540p") ? "540p" : ""; + + $modelPrices := $lookup($p, $modelKey); + $durPrices := $lookup($modelPrices, $durKey); + $v := $lookup($durPrices, $resKey); + + $price := + ($modelKey = "ray-1-6") ? 0.5 : + ($modelKey = "other") ? 0.79 : + ($exists($v) ? $v : 0.79); + + {"type":"usd","usd": $price} + ) + """, +) + + class LumaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py new file mode 100644 index 000000000..0f53208d4 --- /dev/null +++ b/comfy_api_nodes/nodes_magnific.py @@ -0,0 +1,945 @@ +import math + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.magnific import ( + ImageRelightAdvancedSettingsRequest, + ImageRelightRequest, + ImageSkinEnhancerCreativeRequest, + ImageSkinEnhancerFaithfulRequest, + ImageSkinEnhancerFlexibleRequest, + ImageStyleTransferRequest, + ImageUpscalerCreativeRequest, + ImageUpscalerPrecisionV2Request, + InputAdvancedSettings, + InputPortraitMode, + InputSkinEnhancerMode, + TaskResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + downscale_image_tensor, + get_image_dimensions, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, +) + +_EUR_TO_USD = 1.19 + + +def _tier_price_eur(megapixels: float) -> float: + """Price in EUR for a single Magnific upscaling step based on input megapixels.""" + if megapixels <= 1.3: + return 0.143 + if megapixels <= 3.0: + return 0.286 + if megapixels <= 6.4: + return 0.429 + return 1.716 + + +def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float: + """Calculate total Magnific upscale price in USD for given input dimensions and scale factor.""" + num_steps = int(math.log2(scale)) + total_eur = 0.0 + pixels = width * height + for _ in range(num_steps): + total_eur += _tier_price_eur(pixels / 1_000_000) + pixels *= 4 + return round(total_eur * _EUR_TO_USD, 2) + + +class 
MagnificImageUpscalerCreativeNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerCreativeNode", + display_name="Magnific Image Upscale (Creative)", + category="api node/image/Magnific", + description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " + "Maximum output: 25.3 megapixels.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "optimized_for", + options=[ + "standard", + "soft_portraits", + "hard_portraits", + "art_n_illustration", + "videogame_assets", + "nature_n_landscapes", + "films_n_photography", + "3d_renders", + "science_fiction_n_horror", + ], + ), + IO.Int.Input("creativity", min=-10, max=10, default=0, display_mode=IO.NumberDisplay.slider), + IO.Int.Input( + "hdr", + min=-10, + max=10, + default=0, + tooltip="The level of definition and detail.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "resemblance", + min=-10, + max=10, + default=0, + tooltip="The level of resemblance to the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "fractality", + min=-10, + max=10, + default=0, + tooltip="The strength of the prompt and intricacy per square pixel.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=["automatic", "magnific_illusio", "magnific_sharpy", "magnific_sparkle"], + advanced=True, + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum pixel limit.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]), + expr=""" + ( + $ad := widgets.auto_downscale; + $mins := $ad + ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515} + : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + scale_factor: str, + optimized_for: str, + creativity: int, + hdr: int, + resemblance: int, + fractality: int, + engine: str, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_pixels = 25_300_000 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.rstrip("x")) + output_pixels = height * width * requested_scale * requested_scale + + if output_pixels > max_output_pixels: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. 
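+ # Fallback defaults: if no candidate in the loop below fits, keep 2x and cap the input at a quarter of the output pixel budget.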
+ input_pixels = width * height + scale = 2 + max_input_pixels = max_output_pixels // 4 + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + scale_output_pixels = input_pixels * candidate * candidate + if scale_output_pixels <= max_output_pixels: + scale = candidate + max_input_pixels = None + break + downscale_ratio = math.sqrt(scale_output_pixels / max_output_pixels) + if downscale_ratio <= 2.0: + scale = candidate + max_input_pixels = max_output_pixels // (candidate * candidate) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + scale_factor = f"{scale}x" + else: + raise ValueError( + f"Output size ({width * requested_scale}x{height * requested_scale} = {output_pixels:,} pixels) " + f"exceeds maximum allowed size of {max_output_pixels:,} pixels. " + f"Use a smaller input image or lower scale factor." + ) + + final_height, final_width = get_image_dimensions(image) + actual_scale = int(scale_factor.rstrip("x")) + price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerCreativeRequest( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=scale_factor, + optimized_for=optimized_for, + creativity=creativity, + hdr=hdr, + resemblance=resemblance, + fractality=fractality, + engine=engine, + prompt=prompt if prompt else None, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + price_extractor=lambda _: price_usd, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageUpscalerPreciseV2Node", + display_name="Magnific Image Upscale (Precise V2)", + category="api node/image/Magnific", + description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " + "Maximum output: 10060×10060 pixels.", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("scale_factor", options=["2x", "4x", "8x", "16x"]), + IO.Combo.Input( + "flavor", + options=["sublime", "photo", "photo_denoiser"], + tooltip="Processing style: " + "sublime for general use, photo for photographs, photo_denoiser for noisy photos.", + ), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=7, + tooltip="Image sharpness intensity. 
Higher values increase edge definition and clarity.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=7, + tooltip="Intelligent grain/texture enhancement to prevent the image from " + "looking too smooth or artificial.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "ultra_detail", + min=0, + max=100, + default=30, + tooltip="Controls fine detail, textures, and micro-details added during upscaling.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Boolean.Input( + "auto_downscale", + default=False, + tooltip="Automatically downscale input image if output would exceed maximum resolution.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]), + expr=""" + ( + $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844}; + $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06}; + { + "type": "range_usd", + "min_usd": $lookup($mins, widgets.scale_factor), + "max_usd": $lookup($maxs, widgets.scale_factor), + "format": { "approximate": true } + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + scale_factor: str, + flavor: str, + sharpen: int, + smart_grain: int, + ultra_detail: int, + auto_downscale: bool, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + max_output_dimension = 10060 + height, width = get_image_dimensions(image) + requested_scale = int(scale_factor.strip("x")) + output_width = width * requested_scale + output_height = height * requested_scale + + if output_width > max_output_dimension or output_height > max_output_dimension: + if auto_downscale: + # Find optimal scale factor that doesn't require >2x downscale. + # Server upscales in 2x steps, so aggressive downscaling degrades quality. + max_dim = max(width, height) + scale = 2 + max_input_dim = max_output_dimension // 2 + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + for candidate in [16, 8, 4, 2]: + if candidate > requested_scale: + continue + output_dim = max_dim * candidate + if output_dim <= max_output_dimension: + scale = candidate + max_input_pixels = None + break + downscale_ratio = output_dim / max_output_dimension + if downscale_ratio <= 2.0: + scale = candidate + max_input_dim = max_output_dimension // candidate + scale_ratio = max_input_dim / max_dim + max_input_pixels = int(width * height * scale_ratio * scale_ratio) + break + + if max_input_pixels is not None: + image = downscale_image_tensor(image, total_pixels=max_input_pixels) + requested_scale = scale + else: + raise ValueError( + f"Output dimensions ({output_width}x{output_height}) exceed maximum allowed " + f"resolution of {max_output_dimension}x{max_output_dimension} pixels. " + f"Use a smaller input image or lower scale factor." 
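+                        # e.g. a hypothetical 2600x2000 input at 8x would yield
+                        # 20800x16000, over the 10060 px per-side cap, and is
+                        # rejected here unless auto_downscale is enabled.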
+ ) + + final_height, final_width = get_image_dimensions(image) + price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"), + response_model=TaskResponse, + data=ImageUpscalerPrecisionV2Request( + image=(await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=None))[0], + scale_factor=requested_scale, + flavor=flavor, + sharpen=sharpen, + smart_grain=smart_grain, + ultra_detail=ultra_detail, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + price_extractor=lambda _: price_usd, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageStyleTransferNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageStyleTransferNode", + display_name="Magnific Image Style Transfer", + category="api node/image/Magnific", + description="Transfer the style from a reference image to your input image.", + inputs=[ + IO.Image.Input("image", tooltip="The image to apply style transfer to."), + IO.Image.Input("reference_image", tooltip="The reference image to extract style from."), + IO.String.Input("prompt", multiline=True, default=""), + IO.Int.Input( + "style_strength", + min=0, + max=100, + default=100, + tooltip="Percentage of style strength.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "structure_strength", + min=0, + max=100, + default=50, + tooltip="Maintains the structure of the original image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "flavor", + options=["faithful", "gen_z", "psychedelia", "detaily", "clear", "donotstyle", "donotstyle_sharp"], + tooltip="Style transfer flavor.", + ), + IO.Combo.Input( + "engine", + options=[ + "balanced", + "definio", + "illusio", + "3d_cartoon", + "colorful_anime", + "caricature", + "real", + "super_real", + "softy", + ], + tooltip="Processing engine selection.", + advanced=True, + ), + IO.DynamicCombo.Input( + "portrait_mode", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Combo.Input( + "portrait_style", + options=["standard", "pop", "super_pop"], + tooltip="Visual style applied to portrait images.", + ), + IO.Combo.Input( + "portrait_beautifier", + options=["none", "beautify_face", "beautify_face_max"], + tooltip="Facial beautification intensity on portraits.", + ), + ], + ), + ], + tooltip="Enable portrait mode for facial enhancements.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="When disabled, expect each generation to introduce a degree of randomness, " + "leading to more diverse outcomes.", + advanced=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + reference_image: Input.Image, + prompt: str, + style_strength: int, + structure_strength: int, + flavor: str, + engine: str, + portrait_mode: InputPortraitMode, + fixed_generation: bool, + ) -> IO.NodeOutput: + if 
get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + is_portrait = portrait_mode["portrait_mode"] == "enabled" + portrait_style = portrait_mode.get("portrait_style", "standard") + portrait_beautifier = portrait_mode.get("portrait_beautifier", "none") + + uploaded_urls = await upload_images_to_comfyapi(cls, [image, reference_image], max_images=2) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-style-transfer", method="POST"), + response_model=TaskResponse, + data=ImageStyleTransferRequest( + image=uploaded_urls[0], + reference_image=uploaded_urls[1], + prompt=prompt if prompt else None, + style_strength=style_strength, + structure_strength=structure_strength, + is_portrait=is_portrait, + portrait_style=portrait_style if is_portrait else None, + portrait_beautifier=portrait_beautifier if is_portrait and portrait_beautifier != "none" else None, + flavor=flavor, + engine=engine, + fixed_generation=fixed_generation, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-style-transfer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageRelightNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageRelightNode", + display_name="Magnific Image Relight", + category="api node/image/Magnific", + description="Relight an image with lighting adjustments and optional reference-based light transfer.", + inputs=[ + IO.Image.Input("image", tooltip="The image to relight."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Descriptive guidance for lighting. 
Supports emphasis notation (1-1.4).", + ), + IO.Int.Input( + "light_transfer_strength", + min=0, + max=100, + default=100, + tooltip="Intensity of light transfer application.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "style", + options=[ + "standard", + "darker_but_realistic", + "clean", + "smooth", + "brighter", + "contrasted_n_hdr", + "just_composition", + ], + tooltip="Stylistic output preference.", + ), + IO.Boolean.Input( + "interpolate_from_original", + default=False, + tooltip="Restricts generation freedom to match original more closely.", + advanced=True, + ), + IO.Boolean.Input( + "change_background", + default=True, + tooltip="Modifies background based on prompt/reference.", + advanced=True, + ), + IO.Boolean.Input( + "preserve_details", + default=True, + tooltip="Maintains texture and fine details from original.", + advanced=True, + ), + IO.DynamicCombo.Input( + "advanced_settings", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "whites", + min=0, + max=100, + default=50, + tooltip="Adjusts the brightest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "blacks", + min=0, + max=100, + default=50, + tooltip="Adjusts the darkest tones in the image.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "brightness", + min=0, + max=100, + default=50, + tooltip="Overall brightness adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "contrast", + min=0, + max=100, + default=50, + tooltip="Contrast adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "saturation", + min=0, + max=100, + default=50, + tooltip="Color saturation adjustment.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Combo.Input( + "engine", + options=[ + "automatic", + "balanced", + "cool", + "real", + "illusio", + "fairy", + "colorful_anime", + "hard_transform", + "softy", + ], + tooltip="Processing engine selection.", + ), + IO.Combo.Input( + "transfer_light_a", + options=["automatic", "low", "medium", "normal", "high", "high_on_faces"], + tooltip="The intensity of light transfer.", + ), + IO.Combo.Input( + "transfer_light_b", + options=[ + "automatic", + "composition", + "straight", + "smooth_in", + "smooth_out", + "smooth_both", + "reverse_both", + "soft_in", + "soft_out", + "soft_mid", + # "strong_mid", # Commented out because requests fail when this is set. + "style_shift", + "strong_shift", + ], + tooltip="Also modifies light transfer intensity. 
" + "Can be combined with the previous control for varied effects.", + ), + IO.Boolean.Input( + "fixed_generation", + default=True, + tooltip="Ensures consistent output with the same settings.", + ), + ], + ), + ], + tooltip="Fine-tuning options for advanced lighting control.", + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip="Optional reference image to transfer lighting from.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + light_transfer_strength: int, + style: str, + interpolate_from_original: bool, + change_background: bool, + preserve_details: bool, + advanced_settings: InputAdvancedSettings, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if reference_image is not None and get_number_of_images(reference_image) != 1: + raise ValueError("Exactly one reference image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + if reference_image is not None: + validate_image_aspect_ratio(reference_image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(reference_image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] + reference_url = None + if reference_image is not None: + reference_url = (await upload_images_to_comfyapi(cls, reference_image, max_images=1))[0] + + adv_settings = None + if advanced_settings["advanced_settings"] == "enabled": + adv_settings = ImageRelightAdvancedSettingsRequest( + whites=advanced_settings["whites"], + blacks=advanced_settings["blacks"], + brightness=advanced_settings["brightness"], + contrast=advanced_settings["contrast"], + saturation=advanced_settings["saturation"], + engine=advanced_settings["engine"], + transfer_light_a=advanced_settings["transfer_light_a"], + transfer_light_b=advanced_settings["transfer_light_b"], + fixed_generation=advanced_settings["fixed_generation"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/freepik/v1/ai/image-relight", method="POST"), + response_model=TaskResponse, + data=ImageRelightRequest( + image=image_url, + prompt=prompt if prompt else None, + transfer_light_from_reference_image=reference_url, + light_transfer_strength=light_transfer_strength, + interpolate_from_original=interpolate_from_original, + change_background=change_background, + style=style, + preserve_details=preserve_details, + advanced_settings=adv_settings, + ), + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-relight/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificImageSkinEnhancerNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MagnificImageSkinEnhancerNode", + display_name="Magnific Image Skin Enhancer", + category="api node/image/Magnific", + description="Skin enhancement for portraits with multiple processing modes.", + inputs=[ + IO.Image.Input("image", 
tooltip="The portrait image to enhance."), + IO.Int.Input( + "sharpen", + min=0, + max=100, + default=0, + tooltip="Sharpening intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "smart_grain", + min=0, + max=100, + default=2, + tooltip="Smart grain intensity level.", + display_mode=IO.NumberDisplay.slider, + ), + IO.DynamicCombo.Input( + "mode", + options=[ + IO.DynamicCombo.Option("creative", []), + IO.DynamicCombo.Option( + "faithful", + [ + IO.Int.Input( + "skin_detail", + min=0, + max=100, + default=80, + tooltip="Skin detail enhancement level.", + display_mode=IO.NumberDisplay.slider, + ), + ], + ), + IO.DynamicCombo.Option( + "flexible", + [ + IO.Combo.Input( + "optimized_for", + options=[ + "enhance_skin", + "improve_lighting", + "enhance_everything", + "transform_to_real", + "no_make_up", + ], + tooltip="Enhancement optimization target.", + ), + ], + ), + ], + tooltip="Processing mode: creative for artistic enhancement, " + "faithful for preserving original appearance, " + "flexible for targeted optimization.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode"]), + expr=""" + ( + $rates := {"creative": 0.29, "faithful": 0.37, "flexible": 0.45}; + {"type":"usd","usd": $lookup($rates, widgets.mode)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + sharpen: int, + smart_grain: int, + mode: InputSkinEnhancerMode, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + validate_image_aspect_ratio(image, (1, 3), (3, 1), strict=False) + validate_image_dimensions(image, min_height=160, min_width=160) + + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, total_pixels=4096 * 4096))[0] + selected_mode = mode["mode"] + + if selected_mode == "creative": + endpoint = "creative" + data = ImageSkinEnhancerCreativeRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + ) + elif selected_mode == "faithful": + endpoint = "faithful" + data = ImageSkinEnhancerFaithfulRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + skin_detail=mode["skin_detail"], + ) + else: # flexible + endpoint = "flexible" + data = ImageSkinEnhancerFlexibleRequest( + image=image_url, + sharpen=sharpen, + smart_grain=smart_grain, + optimized_for=mode["optimized_for"], + ) + + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{endpoint}", method="POST"), + response_model=TaskResponse, + data=data, + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/freepik/v1/ai/skin-enhancer/{initial_res.task_id}"), + response_model=TaskResponse, + status_extractor=lambda x: x.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) + + +class MagnificExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MagnificImageUpscalerCreativeNode, + MagnificImageUpscalerPreciseV2Node, + MagnificImageStyleTransferNode, + MagnificImageRelightNode, + MagnificImageSkinEnhancerNode, + ] + + +async def comfy_entrypoint() -> MagnificExtension: + return MagnificExtension() diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py new file 
mode 100644 index 000000000..3cf577f4a --- /dev/null +++ b/comfy_api_nodes/nodes_meshy.py @@ -0,0 +1,831 @@ +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.meshy import ( + InputShouldRemesh, + InputShouldTexture, + MeshyAnimationRequest, + MeshyAnimationResult, + MeshyImageToModelRequest, + MeshyModelResult, + MeshyMultiImageToModelRequest, + MeshyRefineTask, + MeshyRiggedResult, + MeshyRiggingRequest, + MeshyTaskResponse, + MeshyTextToModelRequest, + MeshyTextureRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_file_3d, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_string, +) + + +class MeshyTextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextToModelNode", + display_name="Meshy: Text to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.String.Input("prompt", multiline=True, default=""), + IO.Combo.Input("style", options=["realistic", "sculpture"]), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"], advanced=True), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + advanced=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.8}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + style: str, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=600) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextToModelRequest( + prompt=prompt, + art_style=style, + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + should_remesh=should_remesh["should_remesh"] == "true", + symmetry_mode=symmetry_mode, + pose_mode=pose_mode.lower(), + seed=seed, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + 
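+            # Output order mirrors the schema above: the legacy model_file string,
+            # the Meshy task id, then the downloaded GLB and FBX files.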
await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) + + +class MeshyRefineNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRefineNode", + display_name="Meshy: Refine Draft Model", + category="api node/3d/Meshy", + description="Refine a previously created draft model.", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) in addition to the base color. " + "Note: this should be set to false when using Sculpture style, " + "as Sculpture style generates its own set of PBR maps.", + advanced=True, + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' may be used at the same time.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_pbr: bool, + texture_prompt: str, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + if texture_prompt and texture_image is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + texture_image_url = None + if texture_prompt: + validate_string(texture_prompt, field_name="texture_prompt", max_length=600) + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v2/text-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRefineTask( + preview_task_id=meshy_task_id, + enable_pbr=enable_pbr, + texture_prompt=texture_prompt if texture_prompt else None, + texture_image_url=texture_image_url, + ai_model=model, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v2/text-to-3d/{task_id}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) + + +class MeshyImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyImageToModelNode", + display_name="Meshy: Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Image.Input("image"), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", 
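+                                # Applies only when should_remesh is "true"; with
+                                # remeshing disabled the API returns the raw
+                                # triangular mesh instead.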
options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"]), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. " + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + advanced=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 1.2, "false": 0.8}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyImageToModelRequest( + image_url=(await upload_images_to_comfyapi(cls, image, wait_label="Uploading base image"))[0], + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + 
symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/image-to-3d/{task_id}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) + + +class MeshyMultiImageToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyMultiImageToModelNode", + display_name="Meshy: Multi-Image to Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=2, max=4), + ), + IO.DynamicCombo.Input( + "should_remesh", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Combo.Input("topology", options=["triangle", "quad"]), + IO.Int.Input( + "target_polycount", + default=300000, + min=100, + max=300000, + display_mode=IO.NumberDisplay.number, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="When set to false, returns an unprocessed triangular mesh.", + ), + IO.Combo.Input("symmetry_mode", options=["auto", "on", "off"], advanced=True), + IO.DynamicCombo.Input( + "should_texture", + options=[ + IO.DynamicCombo.Option( + "true", + [ + IO.Boolean.Input( + "enable_pbr", + default=False, + tooltip="Generate PBR Maps (metallic, roughness, normal) " + "in addition to the base color.", + ), + IO.String.Input( + "texture_prompt", + default="", + multiline=True, + tooltip="Provide a text prompt to guide the texturing process. " + "Maximum 600 characters. Cannot be used at the same time as 'texture_image'.", + ), + IO.Image.Input( + "texture_image", + tooltip="Only one of 'texture_image' or 'texture_prompt' " + "may be used at the same time.", + optional=True, + ), + ], + ), + IO.DynamicCombo.Option("false", []), + ], + tooltip="Determines whether textures are generated. 
" + "Setting it to false skips the texture phase and returns a mesh without textures.", + ), + IO.Combo.Input( + "pose_mode", + options=["", "A-pose", "T-pose"], + tooltip="Specify the pose mode for the generated model.", + advanced=True, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MESHY_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["should_texture"]), + expr=""" + ( + $prices := {"true": 0.6, "false": 0.2}; + {"type":"usd","usd": $lookup($prices, widgets.should_texture)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + images: IO.Autogrow.Type, + should_remesh: InputShouldRemesh, + symmetry_mode: str, + should_texture: InputShouldTexture, + pose_mode: str, + seed: int, + ) -> IO.NodeOutput: + texture = should_texture["should_texture"] == "true" + texture_image_url = texture_prompt = None + if texture: + if should_texture["texture_prompt"] and should_texture["texture_image"] is not None: + raise ValueError("texture_prompt and texture_image cannot be used at the same time") + if should_texture["texture_prompt"]: + validate_string(should_texture["texture_prompt"], field_name="texture_prompt", max_length=600) + texture_prompt = should_texture["texture_prompt"] + if should_texture["texture_image"] is not None: + texture_image_url = ( + await upload_images_to_comfyapi( + cls, should_texture["texture_image"], wait_label="Uploading texture" + ) + )[0] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/meshy/openapi/v1/multi-image-to-3d", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyMultiImageToModelRequest( + image_urls=await upload_images_to_comfyapi( + cls, list(images.values()), wait_label="Uploading base images" + ), + ai_model=model, + topology=should_remesh.get("topology", None), + target_polycount=should_remesh.get("target_polycount", None), + symmetry_mode=symmetry_mode, + should_remesh=should_remesh["should_remesh"] == "true", + should_texture=texture, + enable_pbr=should_texture.get("enable_pbr", None), + pose_mode=pose_mode.lower(), + texture_prompt=texture_prompt, + texture_image_url=texture_image_url, + seed=seed, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/multi-image-to-3d/{task_id}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) + + +class MeshyRigModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyRigModelNode", + display_name="Meshy: Rig Model", + category="api node/3d/Meshy", + description="Provides a rigged character in standard formats. 
" + "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " + "or humanoid assets with unclear limb and body structure.", + inputs=[ + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Float.Input( + "height_meters", + min=0.1, + max=15.0, + default=1.7, + tooltip="The approximate height of the character model in meters. " + "This aids in scaling and rigging accuracy.", + ), + IO.Image.Input( + "texture_image", + tooltip="The model's UV-unwrapped base color texture image.", + optional=True, + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MESHY_RIGGED_TASK_ID").Output(display_name="rig_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), + ) + + @classmethod + async def execute( + cls, + meshy_task_id: str, + height_meters: float, + texture_image: Input.Image | None = None, + ) -> IO.NodeOutput: + texture_image_url = None + if texture_image is not None: + texture_image_url = (await upload_images_to_comfyapi(cls, texture_image, wait_label="Uploading texture"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/rigging", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyRiggingRequest( + input_task_id=meshy_task_id, + height_meters=height_meters, + texture_image_url=texture_image_url, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/rigging/{task_id}"), + response_model=MeshyRiggedResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.result.rigged_character_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.rigged_character_fbx_url, "fbx", task_id=task_id), + ) + + +class MeshyAnimateModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyAnimateModelNode", + display_name="Meshy: Animate Model", + category="api node/3d/Meshy", + description="Apply a specific animation action to a previously rigged character.", + inputs=[ + IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), + IO.Int.Input( + "action_id", + default=0, + min=0, + max=696, + tooltip="Visit https://docs.meshy.ai/en/api/animation-library for a list of available values.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.12}""", + ), + ) + + @classmethod + async def execute( + cls, + rig_task_id: str, + action_id: int, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/animations", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyAnimationRequest( + rig_task_id=rig_task_id, + action_id=action_id, + ), + ) + task_id = response.result + result = await poll_op( + cls, + 
ApiEndpoint(path=f"/proxy/meshy/openapi/v1/animations/{task_id}"), + response_model=MeshyAnimationResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + await download_url_to_file_3d(result.result.animation_glb_url, "glb", task_id=task_id), + await download_url_to_file_3d(result.result.animation_fbx_url, "fbx", task_id=task_id), + ) + + +class MeshyTextureNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MeshyTextureNode", + display_name="Meshy: Texture Model", + category="api node/3d/Meshy", + inputs=[ + IO.Combo.Input("model", options=["latest"]), + IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), + IO.Boolean.Input( + "enable_original_uv", + default=True, + tooltip="Use the original UV of the model instead of generating new UVs. " + "When enabled, Meshy preserves existing textures from the uploaded model. " + "If the model has no original UV, the quality of the output might not be as good.", + advanced=True, + ), + IO.Boolean.Input("pbr", default=False, advanced=True), + IO.String.Input( + "text_style_prompt", + default="", + multiline=True, + tooltip="Describe your desired texture style of the object using text. Maximum 600 characters." + "Maximum 600 characters. Cannot be used at the same time as 'image_style'.", + ), + IO.Image.Input( + "image_style", + optional=True, + tooltip="A 2d image to guide the texturing process. " + "Can not be used at the same time with 'text_style_prompt'.", + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="meshy_task_id"), + IO.File3DGLB.Output(display_name="GLB"), + IO.File3DFBX.Output(display_name="FBX"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), + ) + + @classmethod + async def execute( + cls, + model: str, + meshy_task_id: str, + enable_original_uv: bool, + pbr: bool, + text_style_prompt: str, + image_style: Input.Image | None = None, + ) -> IO.NodeOutput: + if text_style_prompt and image_style is not None: + raise ValueError("text_style_prompt and image_style cannot be used at the same time") + if not text_style_prompt and image_style is None: + raise ValueError("Either text_style_prompt or image_style is required") + image_style_url = None + if image_style is not None: + image_style_url = (await upload_images_to_comfyapi(cls, image_style, wait_label="Uploading style"))[0] + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/meshy/openapi/v1/retexture", method="POST"), + response_model=MeshyTaskResponse, + data=MeshyTextureRequest( + input_task_id=meshy_task_id, + ai_model=model, + enable_original_uv=enable_original_uv, + enable_pbr=pbr, + text_style_prompt=text_style_prompt if text_style_prompt else None, + image_style_url=image_style_url, + ), + ) + task_id = response.result + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/meshy/openapi/v1/retexture/{task_id}"), + response_model=MeshyModelResult, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + return IO.NodeOutput( + f"{task_id}.glb", + task_id, + await download_url_to_file_3d(result.model_urls.glb, "glb", task_id=task_id), + await download_url_to_file_3d(result.model_urls.fbx, "fbx", task_id=task_id), + ) + + +class 
MeshyExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + MeshyTextToModelNode, + MeshyRefineNode, + MeshyImageToModelNode, + MeshyMultiImageToModelNode, + MeshyRigModelNode, + MeshyAnimateModelNode, + MeshyTextureNode, + ] + + +async def comfy_entrypoint() -> MeshyExtension: + return MeshyExtension() diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 05cbb700f..b5d0b461f 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -4,7 +4,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.minimax_api import ( +from comfy_api_nodes.apis.minimax import ( MinimaxFileRetrieveResponse, MiniMaxModel, MinimaxTaskResultResponse, @@ -134,6 +134,9 @@ class MinimaxTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -197,6 +200,9 @@ class MinimaxImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.43}""", + ), ) @classmethod @@ -340,6 +346,20 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["resolution", "duration"]), + expr=""" + ( + $prices := { + "768p": {"6": 0.28, "10": 0.56}, + "1080p": {"6": 0.49} + }; + $resPrices := $lookup($prices, $lowercase(widgets.resolution)); + $price := $lookup($resPrices, $string(widgets.duration)); + {"type":"usd","usd": $price ? $price : 0.43} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 2771e4790..78a230529 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -3,7 +3,7 @@ import logging from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.moonvalley import ( MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoRequest, @@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=80, + min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) max=100, step=1, tooltip="Number of denoising steps", @@ -233,6 +233,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod @@ -336,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=60, + min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24) max=100, step=1, display_mode=IO.NumberDisplay.number, @@ -351,6 +355,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 2.25}""", + ), ) @classmethod @@ -362,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): video: Input.Video | None = None, control_type: str = "Motion Transfer", motion_intensity: int | None = 100, - steps=33, + steps=60, prompt_adherence=4.5, ) -> IO.NodeOutput: validated_video = validate_video_to_video_input(video) @@ -457,8 
+465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): ), IO.Int.Input( "steps", - default=33, - min=1, + default=80, + min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0) max=100, step=1, tooltip="Inference steps", @@ -471,6 +479,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 1.5}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a6205a34f..4ee896fa8 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -10,24 +10,18 @@ from typing_extensions import override import folder_paths from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import ( - CreateModelResponseProperties, - Detail, - InputContent, +from comfy_api_nodes.apis.openai import ( InputFileContent, InputImageContent, InputMessage, - InputMessageContentList, InputTextContent, - Item, + ModelResponseProperties, OpenAICreateResponse, - OpenAIResponse, - OutputContent, -) -from comfy_api_nodes.apis.openai_api import ( OpenAIImageEditRequest, OpenAIImageGenerationRequest, OpenAIImageGenerationResponse, + OpenAIResponse, + OutputContent, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -49,7 +43,6 @@ class SupportedOpenAIModel(str, Enum): o1 = "o1" o3 = "o3" o1_pro = "o1-pro" - gpt_4o = "gpt-4o" gpt_4_1 = "gpt-4.1" gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_nano = "gpt-4.1-nano" @@ -160,6 +153,23 @@ class OpenAIDalle2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "n"]), + expr=""" + ( + $size := widgets.size; + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; + + $base := + $contains($size, "256x256") ? 0.016 : + $contains($size, "512x512") ? 0.018 : + 0.02; + + {"type":"usd","usd": $round($base * $n, 3)} + ) + """, + ), ) @classmethod @@ -249,7 +259,7 @@ class OpenAIDalle3(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -287,6 +297,25 @@ class OpenAIDalle3(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "quality"]), + expr=""" + ( + $size := widgets.size; + $q := widgets.quality; + $hd := $contains($q, "hd"); + + $price := + $contains($size, "1024x1024") + ? ($hd ? 0.08 : 0.04) + : (($contains($size, "1792x1024") or $contains($size, "1024x1792")) + ? ($hd ? 
0.12 : 0.08) + : 0.04); + + {"type":"usd","usd": $price} + ) + """, + ), ) @classmethod @@ -334,9 +363,9 @@ class OpenAIGPTImage1(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1", + display_name="OpenAI GPT Image 1.5", category="api node/image/OpenAI", - description="Generates images synchronously via OpenAI's GPT Image 1 endpoint.", + description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( "prompt", @@ -348,7 +377,7 @@ class OpenAIGPTImage1(IO.ComfyNode): "seed", default=0, min=0, - max=2 ** 31 - 1, + max=2**31 - 1, step=1, display_mode=IO.NumberDisplay.number, control_after_generate=True, @@ -399,6 +428,7 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "model", options=["gpt-image-1", "gpt-image-1.5"], + default="gpt-image-1.5", optional=True, ), ], @@ -411,6 +441,28 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + expr=""" + ( + $ranges := { + "low": [0.011, 0.02], + "medium": [0.046, 0.07], + "high": [0.167, 0.3] + }; + $range := $lookup($ranges, widgets.quality); + $n := widgets.n; + ($n = 1) + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + : { + "type":"range_usd", + "min_usd": $range[0], + "max_usd": $range[1], + "format": { "suffix": " x " & $string($n) & "/Run" } + } + ) + """, + ), ) @classmethod @@ -442,8 +494,8 @@ class OpenAIGPTImage1(IO.ComfyNode): files = [] batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i: i + 1] - scaled_image = downscale_image_tensor(single_image, total_pixels=2048*2048).squeeze() + single_image = image[i : i + 1] + scaled_image = downscale_image_tensor(single_image, total_pixels=2048 * 2048).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) @@ -465,7 +517,7 @@ class OpenAIGPTImage1(IO.ComfyNode): rgba_mask = torch.zeros(height, width, 4, device="cpu") rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu() - scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048*2048).squeeze() + scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0), total_pixels=2048 * 2048).squeeze() mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) @@ -523,6 +575,7 @@ class OpenAIChatNode(IO.ComfyNode): node_id="OpenAIChatNode", display_name="OpenAI ChatGPT", category="api node/text/OpenAI", + essentials_category="Text Generation", description="Generate text responses from an OpenAI model.", inputs=[ IO.String.Input( @@ -535,6 +588,7 @@ class OpenAIChatNode(IO.ComfyNode): "persist_context", default=False, tooltip="This parameter is deprecated and has no effect.", + advanced=True, ), IO.Combo.Input( "model", @@ -566,32 +620,90 @@ class OpenAIChatNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "o4-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1-pro") ? { + "type": "list_usd", + "usd": [0.15, 0.6], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o1") ? 
{ + "type": "list_usd", + "usd": [0.015, 0.06], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3-mini") ? { + "type": "list_usd", + "usd": [0.0011, 0.0044], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "o3") ? { + "type": "list_usd", + "usd": [0.01, 0.04], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-nano") ? { + "type": "list_usd", + "usd": [0.0001, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1-mini") ? { + "type": "list_usd", + "usd": [0.0004, 0.0016], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-4.1") ? { + "type": "list_usd", + "usd": [0.002, 0.008], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-nano") ? { + "type": "list_usd", + "usd": [0.00005, 0.0004], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5-mini") ? { + "type": "list_usd", + "usd": [0.00025, 0.002], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : $contains($m, "gpt-5") ? { + "type": "list_usd", + "usd": [0.00125, 0.01], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + : {"type": "text", "text": "Token-based"} + ) + """, + ), ) @classmethod - def get_message_content_from_response( - cls, response: OpenAIResponse - ) -> list[OutputContent]: + def get_message_content_from_response(cls, response: OpenAIResponse) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: - if output.root.type == "message": - return output.root.content + if output.type == "message": + return output.content raise TypeError("No output message found in response") @classmethod - def get_text_from_message_content( - cls, message_content: list[OutputContent] - ) -> str: + def get_text_from_message_content(cls, message_content: list[OutputContent]) -> str: """Extract text content from message content.""" for content_item in message_content: - if content_item.root.type == "output_text": - return str(content_item.root.text) + if content_item.type == "output_text": + return str(content_item.text) return "No text output found in response" @classmethod - def tensor_to_input_image_content( - cls, image: torch.Tensor, detail_level: Detail = "auto" - ) -> InputImageContent: + def tensor_to_input_image_content(cls, image: torch.Tensor, detail_level: str = "auto") -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( detail=detail_level, @@ -605,9 +717,9 @@ class OpenAIChatNode(IO.ComfyNode): prompt: str, image: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - ) -> InputMessageContentList: + ) -> list[InputTextContent | InputImageContent | InputFileContent]: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ + content_list: list[InputTextContent | InputImageContent | InputFileContent] = [ InputTextContent(text=prompt, type="input_text"), ] if image is not None: @@ -619,13 +731,9 @@ class OpenAIChatNode(IO.ComfyNode): type="input_image", ) ) - if files is not None: content_list.extend(files) - 
- return InputMessageContentList( - root=content_list, - ) + return content_list @classmethod async def execute( @@ -635,7 +743,7 @@ class OpenAIChatNode(IO.ComfyNode): model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, images: torch.Tensor | None = None, files: list[InputFileContent] | None = None, - advanced_options: CreateModelResponseProperties | None = None, + advanced_options: ModelResponseProperties | None = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) @@ -646,36 +754,28 @@ class OpenAIChatNode(IO.ComfyNode): response_model=OpenAIResponse, data=OpenAICreateResponse( input=[ - Item( - root=InputMessage( - content=cls.create_input_message_contents( - prompt, images, files - ), - role="user", - ) + InputMessage( + content=cls.create_input_message_contents(prompt, images, files), + role="user", ), ], store=True, stream=False, model=model, previous_response_id=None, - **( - advanced_options.model_dump(exclude_none=True) - if advanced_options - else {} - ), + **(advanced_options.model_dump(exclude_none=True) if advanced_options else {}), ), ) response_id = create_response.id # Get result output result_response = await poll_op( - cls, - ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), - response_model=OpenAIResponse, - status_extractor=lambda response: response.status, - completed_statuses=["incomplete", "completed"] - ) + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"], + ) return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) @@ -758,6 +858,7 @@ class OpenAIChatConfig(IO.ComfyNode): options=["auto", "disabled"], default="auto", tooltip="The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", + advanced=True, ), IO.Int.Input( "max_output_tokens", @@ -766,6 +867,7 @@ class OpenAIChatConfig(IO.ComfyNode): max=16384, tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens", optional=True, + advanced=True, ), IO.String.Input( "instructions", @@ -796,7 +898,7 @@ class OpenAIChatConfig(IO.ComfyNode): remove depending on model choice. 
""" return IO.NodeOutput( - CreateModelResponseProperties( + ModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 6e1686af0..e17a24ae7 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,7 @@ import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.pixverse_api import ( +from comfy_api_nodes.apis.pixverse import ( PixverseTextVideoRequest, PixverseImageVideoRequest, PixverseTransitionVideoRequest, @@ -128,6 +128,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -242,6 +243,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -355,6 +357,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=PRICE_BADGE_VIDEO, ) @classmethod @@ -416,6 +419,33 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) +PRICE_BADGE_VIDEO = IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds", "quality", "motion_mode"]), + expr=""" + ( + $prices := { + "5": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 0.6, "fast": 1.2}, + "540p": {"normal": 0.45, "fast": 0.9}, + "360p": {"normal": 0.45, "fast": 0.9} + }, + "8": { + "1080p": {"normal": 1.2, "fast": 1.2}, + "720p": {"normal": 1.2, "fast": 1.2}, + "540p": {"normal": 0.9, "fast": 1.2}, + "360p": {"normal": 0.9, "fast": 1.2} + } + }; + $durPrices := $lookup($prices, $string(widgets.duration_seconds)); + $qualityPrices := $lookup($durPrices, widgets.quality); + $price := $lookup($qualityPrices, widgets.motion_mode); + {"type":"usd","usd": $price ? 
$price : 0.9} + ) + """, +) + + class PixVerseExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py new file mode 100644 index 000000000..61533263f --- /dev/null +++ b/comfy_api_nodes/nodes_quiver.py @@ -0,0 +1,291 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.quiver import ( + QuiverImageObject, + QuiverImageToSVGRequest, + QuiverSVGResponse, + QuiverTextToSVGRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + sync_op, + upload_image_to_comfyapi, + validate_string, +) +from comfy_extras.nodes_images import SVG + + +class QuiverTextToSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="QuiverTextToSVGNode", + display_name="Quiver Text to SVG", + category="api node/image/Quiver", + description="Generate an SVG from a text prompt using Quiver AI.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired SVG output.", + ), + IO.String.Input( + "instructions", + multiline=True, + default="", + tooltip="Additional style or formatting guidance.", + optional=True, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="ref_", + min=0, + max=4, + ), + tooltip="Up to 4 reference images to guide the generation.", + optional=True, + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "arrow-preview", + [ + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. 
Higher values increase randomness.",
+                                    advanced=True,
+                                ),
+                                IO.Float.Input(
+                                    "top_p",
+                                    default=1.0,
+                                    min=0.05,
+                                    max=1.0,
+                                    step=0.05,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Nucleus sampling parameter.",
+                                    advanced=True,
+                                ),
+                                IO.Float.Input(
+                                    "presence_penalty",
+                                    default=0.0,
+                                    min=-2.0,
+                                    max=2.0,
+                                    step=0.1,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Token presence penalty.",
+                                    advanced=True,
+                                ),
+                            ],
+                        ),
+                    ],
+                    tooltip="Model to use for SVG generation.",
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=0,
+                    min=0,
+                    max=2147483647,
+                    control_after_generate=True,
+                    tooltip="Seed to determine if node should re-run; "
+                    "actual results are nondeterministic regardless of seed.",
+                ),
+            ],
+            outputs=[
+                IO.SVG.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                expr="""{"type":"usd","usd":0.429}""",
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        prompt: str,
+        model: dict,
+        seed: int,
+        instructions: str | None = None,
+        reference_images: IO.Autogrow.Type | None = None,
+    ) -> IO.NodeOutput:
+        validate_string(prompt, strip_whitespace=False, min_length=1)
+
+        references = None
+        if reference_images:
+            # Check the count before uploading so an over-limit request
+            # fails fast instead of uploading images only to reject them.
+            if len(reference_images) > 4:
+                raise ValueError("Maximum 4 reference images are allowed.")
+            references = []
+            for key in reference_images:
+                url = await upload_image_to_comfyapi(cls, reference_images[key])
+                references.append(QuiverImageObject(url=url))
+
+        instructions_val = instructions.strip() if instructions else None
+        if instructions_val == "":
+            instructions_val = None
+
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
+            response_model=QuiverSVGResponse,
+            data=QuiverTextToSVGRequest(
+                model=model["model"],
+                prompt=prompt,
+                instructions=instructions_val,
+                references=references,
+                temperature=model.get("temperature"),
+                top_p=model.get("top_p"),
+                presence_penalty=model.get("presence_penalty"),
+            ),
+        )
+
+        svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
+        return IO.NodeOutput(SVG(svg_data))
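QuiverTextToSVGNode validates the reference-image count before its upload loop, so an over-limit batch fails before any image leaves the machine. A minimal sketch of that validate-then-upload pattern, with `upload` as a hypothetical stand-in for `upload_image_to_comfyapi` (the wrapper itself is illustrative, not part of the PR):

    from typing import Awaitable, Callable

    async def upload_refs(
        images: dict,
        upload: Callable[[object], Awaitable[str]],
        limit: int = 4,
    ) -> list[str]:
        # Reject over-limit batches before spending any bandwidth.
        if len(images) > limit:
            raise ValueError(f"Maximum {limit} reference images are allowed.")
        return [await upload(img) for img in images.values()]

+
+
+class QuiverImageToSVGNode(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="QuiverImageToSVGNode",
+            display_name="Quiver Image to SVG",
+            category="api node/image/Quiver",
+            description="Vectorize a raster image into SVG using Quiver AI.",
+            inputs=[
+                IO.Image.Input(
+                    "image",
+                    tooltip="Input image to vectorize.",
+                ),
+                IO.Boolean.Input(
+                    "auto_crop",
+                    default=False,
+                    tooltip="Automatically crop to the dominant subject.",
+                ),
+                IO.DynamicCombo.Input(
+                    "model",
+                    options=[
+                        IO.DynamicCombo.Option(
+                            "arrow-preview",
+                            [
+                                IO.Int.Input(
+                                    "target_size",
+                                    default=1024,
+                                    min=128,
+                                    max=4096,
+                                    tooltip="Square resize target in pixels.",
+                                ),
+                                IO.Float.Input(
+                                    "temperature",
+                                    default=1.0,
+                                    min=0.0,
+                                    max=2.0,
+                                    step=0.1,
+                                    display_mode=IO.NumberDisplay.slider,
+                                    tooltip="Randomness control. 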
Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ], + ), + ], + tooltip="Model to use for SVG vectorization.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.429}""", + ), + ) + + @classmethod + async def execute( + cls, + image, + auto_crop: bool, + model: dict, + seed: int, + ) -> IO.NodeOutput: + image_url = await upload_image_to_comfyapi(cls, image) + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"), + response_model=QuiverSVGResponse, + data=QuiverImageToSVGRequest( + model=model["model"], + image=QuiverImageObject(url=image_url), + auto_crop=auto_crop if auto_crop else None, + target_size=model.get("target_size"), + temperature=model.get("temperature"), + top_p=model.get("top_p"), + presence_penalty=model.get("presence_penalty"), + ), + ) + + svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data] + return IO.NodeOutput(SVG(svg_data)) + + +class QuiverExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + QuiverTextToSVGNode, + QuiverImageToSVGNode, + ] + + +async def comfy_entrypoint() -> QuiverExtension: + return QuiverExtension() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e3440b946..c60cfbc4a 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,5 +1,4 @@ from io import BytesIO -from typing import Optional, Union import aiohttp import torch @@ -8,15 +7,18 @@ from typing_extensions import override from comfy.utils import ProgressBar from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.recraft_api import ( +from comfy_api_nodes.apis.recraft import ( + RECRAFT_V4_PRO_SIZES, + RECRAFT_V4_SIZES, RecraftColor, RecraftColorChain, RecraftControls, + RecraftCreateStyleRequest, + RecraftCreateStyleResponse, RecraftImageGenerationRequest, RecraftImageGenerationResponse, RecraftImageSize, RecraftIO, - RecraftModel, RecraftStyle, RecraftStyleV3, get_v3_substyles, @@ -37,7 +39,7 @@ async def handle_recraft_file_request( cls: type[IO.ComfyNode], image: torch.Tensor, path: str, - mask: Optional[torch.Tensor] = None, + mask: torch.Tensor | None = None, total_pixels: int = 4096 * 4096, timeout: int = 1024, request=None, @@ -71,11 +73,11 @@ async def handle_recraft_file_request( def recraft_multipart_parser( data, parent_key=None, - formatter: Optional[type[callable]] = None, - converted_to_check: Optional[list[list]] = None, + formatter: type[callable] | None = None, + converted_to_check: list[list] | None = None, is_list: bool = False, return_mode: str = "formdata", # "dict" | "formdata" -) -> Union[dict, aiohttp.FormData]: +) -> dict | aiohttp.FormData: """ Formats data such that 
multipart/form-data will work with aiohttp library when both files and data are present. @@ -307,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", category="api node/image/Recraft", - description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", + description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), ], @@ -323,6 +325,75 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.NodeOutput(RecraftStyle(style_id=style_id)) +class RecraftCreateStyleNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreateStyleNode", + display_name="Recraft Create Style", + category="api node/image/Recraft", + description="Create a custom style from reference images. " + "Upload 1-5 images to use as style references. " + "Total size of all images is limited to 5 MB.", + inputs=[ + IO.Combo.Input( + "style", + options=["realistic_image", "digital_illustration"], + tooltip="The base style of the generated images.", + ), + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image", + min=1, + max=5, + ), + ), + ], + outputs=[ + IO.String.Output(display_name="style_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd": 0.04}""", + ), + ) + + @classmethod + async def execute( + cls, + style: str, + images: IO.Autogrow.Type, + ) -> IO.NodeOutput: + files = [] + total_size = 0 + max_total_size = 5 * 1024 * 1024 # 5 MB limit + for i, img in enumerate(list(images.values())): + file_bytes = tensor_to_bytesio(img, total_pixels=2048 * 2048, mime_type="image/webp").read() + total_size += len(file_bytes) + if total_size > max_total_size: + raise Exception("Total size of all images exceeds 5 MB limit.") + files.append((f"file{i + 1}", file_bytes)) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/recraft/styles", method="POST"), + response_model=RecraftCreateStyleResponse, + files=files, + data=RecraftCreateStyleRequest(style=style), + content_type="multipart/form-data", + max_retries=1, + ) + + return IO.NodeOutput(response.id) + + class RecraftTextToImageNode(IO.ComfyNode): @classmethod def define_schema(cls): @@ -378,6 +449,10 @@ class RecraftTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -391,7 +466,7 @@ class RecraftTextToImageNode(IO.ComfyNode): negative_prompt: str = None, recraft_controls: RecraftControls = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False, max_length=1000) + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -410,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode): data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", size=size, n=n, style=recraft_style.style, @@ -490,6 +565,10 @@ class 
RecraftImageToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -519,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, strength=round(strength, 2), style=recraft_style.style, @@ -591,6 +670,10 @@ class RecraftImageInpaintingNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.04 * widgets.n, 2)}""", + ), ) @classmethod @@ -615,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, style=recraft_style.style, substyle=recraft_style.substyle, @@ -692,6 +775,10 @@ class RecraftTextToVectorNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["n"]), + expr="""{"type":"usd","usd": $round(0.08 * widgets.n, 2)}""", + ), ) @classmethod @@ -723,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", size=size, n=n, style=recraft_style.style, @@ -746,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", category="api node/image/Recraft", + essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ IO.Image.Input("image"), @@ -759,6 +847,10 @@ class RecraftVectorizeImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(), + expr="""{"type":"usd","usd": 0.01}""", + ), ) @classmethod @@ -817,6 +909,9 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.04}""", + ), ) @classmethod @@ -839,7 +934,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): request = RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, - model=RecraftModel.recraftv3, + model="recraftv3", n=n, style=recraft_style.style, substyle=recraft_style.substyle, @@ -869,6 +964,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): node_id="RecraftRemoveBackgroundNode", display_name="Recraft Remove Background", category="api node/image/Recraft", + essentials_category="Image Tools", description="Remove background from image, and return processed image and mask.", inputs=[ IO.Image.Input("image"), @@ -883,6 +979,9 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -929,6 +1028,9 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.004}""", + ), ) @classmethod @@ -972,9 +1074,258 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) +class 
RecraftV4TextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftV4TextToImageNode", + display_name="Recraft V4 Text to Image", + category="api node/image/Recraft", + description="Generates images using Recraft V4 or V4 Pro models.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt for the image generation. Maximum 10,000 characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "recraftv4", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_SIZES, + default="1024x1024", + tooltip="The size of the generated image.", + ), + ], + ), + IO.DynamicCombo.Option( + "recraftv4_pro", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_PRO_SIZES, + default="2048x2048", + tooltip="The size of the generated image.", + ), + ], + ), + ], + tooltip="The model to use for generation.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), + expr=""" + ( + $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25}; + {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + model: dict, + n: int, + seed: int, + recraft_controls: RecraftControls | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + controls=recraft_controls.create_api_model() if recraft_controls else None, + ), + max_retries=1, + ) + images = [] + for data in response.data: + with handle_recraft_image_output(): + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) + if len(image.shape) < 4: + image = image.unsqueeze(0) + images.append(image) + return IO.NodeOutput(torch.cat(images, dim=0)) + + +class RecraftV4TextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftV4TextToVectorNode", + display_name="Recraft V4 Text to Vector", + category="api node/image/Recraft", + description="Generates SVG using Recraft V4 or V4 Pro models.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt for the image generation. 
Maximum 10,000 characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + tooltip="An optional text description of undesired elements on an image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "recraftv4", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_SIZES, + default="1024x1024", + tooltip="The size of the generated image.", + ), + ], + ), + IO.DynamicCombo.Option( + "recraftv4_pro", + [ + IO.Combo.Input( + "size", + options=RECRAFT_V4_PRO_SIZES, + default="2048x2048", + tooltip="The size of the generated image.", + ), + ], + ), + ], + tooltip="The model to use for generation.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]), + expr=""" + ( + $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30}; + {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + model: dict, + n: int, + seed: int, + recraft_controls: RecraftControls | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( + prompt=prompt, + negative_prompt=negative_prompt if negative_prompt else None, + model=model["model"], + size=model["size"], + n=n, + style="vector_illustration", + substyle=None, + controls=recraft_controls.create_api_model() if recraft_controls else None, + ), + max_retries=1, + ) + svg_data = [] + for data in response.data: + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) + return IO.NodeOutput(SVG(svg_data)) + + class RecraftExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -992,8 +1343,11 @@ class RecraftExtension(ComfyExtension): RecraftStyleV3DigitalIllustrationNode, RecraftStyleV3LogoRasterNode, RecraftStyleInfiniteStyleLibrary, + RecraftCreateStyleNode, RecraftColorRGBNode, RecraftControlsNode, + RecraftV4TextToImageNode, + RecraftV4TextToVectorNode, ] diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py new file mode 100644 index 000000000..a87395394 --- /dev/null +++ b/comfy_api_nodes/nodes_reve.py @@ -0,0 +1,424 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.reve import ( + ReveImageCreateRequest, + ReveImageEditRequest, + ReveImageRemixRequest, + RevePostprocessingOperation, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + sync_op_raw, + tensor_to_base64_string, + validate_string, +) + + +def _build_postprocessing(upscale: dict, 
remove_background: bool) -> list[RevePostprocessingOperation] | None: + ops = [] + if upscale["upscale"] == "enabled": + ops.append( + RevePostprocessingOperation( + process="upscale", + upscale_factor=upscale["upscale_factor"], + ) + ) + if remove_background: + ops.append(RevePostprocessingOperation(process="remove_background")) + return ops or None + + +def _postprocessing_inputs(): + return [ + IO.DynamicCombo.Input( + "upscale", + options=[ + IO.DynamicCombo.Option("disabled", []), + IO.DynamicCombo.Option( + "enabled", + [ + IO.Int.Input( + "upscale_factor", + default=2, + min=2, + max=4, + step=1, + tooltip="Upscale factor (2x, 3x, or 4x).", + ), + ], + ), + ], + tooltip="Upscale the generated image. May add additional cost.", + ), + IO.Boolean.Input( + "remove_background", + default=False, + tooltip="Remove the background from the generated image. May add additional cost.", + ), + ] + + +def _reve_price_extractor(headers: dict) -> float | None: + credits_used = headers.get("x-reve-credits-used") + if credits_used is not None: + return float(credits_used) / 524.48 + return None + + +def _reve_response_header_validator(headers: dict) -> None: + error_code = headers.get("x-reve-error-code") + if error_code: + raise ValueError(f"Reve API error: {error_code}") + if headers.get("x-reve-content-violation", "").lower() == "true": + raise ValueError("The generated image was flagged for content policy violation.") + + +def _model_inputs(versions: list[str], aspect_ratios: list[str]): + return [ + IO.DynamicCombo.Option( + version, + [ + IO.Combo.Input( + "aspect_ratio", + options=aspect_ratios, + tooltip="Aspect ratio of the output image.", + ), + IO.Int.Input( + "test_time_scaling", + default=1, + min=1, + max=5, + step=1, + tooltip="Higher values produce better images but cost more credits.", + advanced=True, + ), + ], + ) + for version in versions + ] + + +class ReveImageCreateNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageCreateNode", + display_name="Reve Image Create", + category="api node/image/Reve", + description="Generate images from text descriptions using Reve.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-create@20250915"], + aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for generation.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["upscale", "upscale.upscale_factor"], + ), + expr=""" + ( + $factor := $lookup(widgets, "upscale.upscale_factor"); + $fmt := {"approximate": true, "note": "(base)"}; + widgets.upscale = "enabled" ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0762, "format": $fmt} + : $factor = 3 ? 
{"type": "usd", "usd": 0.0591, "format": $fmt} + : {"type": "usd", "usd": 0.0457, "format": $fmt} + ) : {"type": "usd", "usd": 0.03432, "format": $fmt} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/create", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageCreateRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + version=model["model"], + test_time_scaling=model["test_time_scaling"], + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageEditNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageEditNode", + display_name="Reve Image Edit", + category="api node/image/Reve", + description="Edit images using natural language instructions with Reve.", + inputs=[ + IO.Image.Input("image", tooltip="The image to edit."), + IO.String.Input( + "edit_instruction", + multiline=True, + default="", + tooltip="Text description of how to edit the image. Maximum 2560 characters.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-edit@20250915", "reve-edit-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for editing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "upscale", "upscale.upscale_factor"], + ), + expr=""" + ( + $fmt := {"approximate": true, "note": "(base)"}; + $isFast := $contains(widgets.model, "fast"); + $enabled := widgets.upscale = "enabled"; + $factor := $lookup(widgets, "upscale.upscale_factor"); + $isFast + ? {"type": "usd", "usd": 0.01001, "format": $fmt} + : $enabled ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt} + : $factor = 3 ? 
{"type": "usd", "usd": 0.0819, "format": $fmt} + : {"type": "usd", "usd": 0.0686, "format": $fmt} + ) : {"type": "usd", "usd": 0.0572, "format": $fmt} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + edit_instruction: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(edit_instruction, min_length=1, max_length=2560) + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/edit", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageEditRequest( + edit_instruction=edit_instruction, + reference_image=tensor_to_base64_string(image), + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveImageRemixNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ReveImageRemixNode", + display_name="Reve Image Remix", + category="api node/image/Reve", + description="Combine reference images with text prompts to create new images using Reve.", + inputs=[ + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="image_", + min=1, + max=6, + ), + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired image. " + "May include XML img tags to reference specific images by index, " + "e.g. 0, 1, etc.", + ), + IO.DynamicCombo.Input( + "model", + options=_model_inputs( + ["reve-remix@20250915", "reve-remix-fast@20251030"], + aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"], + ), + tooltip="Model version to use for remixing.", + ), + *_postprocessing_inputs(), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "upscale", "upscale.upscale_factor"], + ), + expr=""" + ( + $fmt := {"approximate": true, "note": "(base)"}; + $isFast := $contains(widgets.model, "fast"); + $enabled := widgets.upscale = "enabled"; + $factor := $lookup(widgets, "upscale.upscale_factor"); + $isFast + ? {"type": "usd", "usd": 0.01001, "format": $fmt} + : $enabled ? ( + $factor = 4 ? {"type": "usd", "usd": 0.0991, "format": $fmt} + : $factor = 3 ? 
{"type": "usd", "usd": 0.0819, "format": $fmt} + : {"type": "usd", "usd": 0.0686, "format": $fmt} + ) : {"type": "usd", "usd": 0.0572, "format": $fmt} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + reference_images: IO.Autogrow.Type, + prompt: str, + model: dict, + upscale: dict, + remove_background: bool, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2560) + if not reference_images: + raise ValueError("At least one reference image is required.") + ref_base64_list = [] + for key in reference_images: + ref_base64_list.append(tensor_to_base64_string(reference_images[key])) + if len(ref_base64_list) > 6: + raise ValueError("Maximum 6 reference images are allowed.") + tts = model["test_time_scaling"] + ar = model["aspect_ratio"] + response = await sync_op_raw( + cls, + ApiEndpoint( + path="/proxy/reve/v1/image/remix", + method="POST", + headers={"Accept": "image/webp"}, + ), + as_binary=True, + price_extractor=_reve_price_extractor, + response_header_validator=_reve_response_header_validator, + data=ReveImageRemixRequest( + prompt=prompt, + reference_images=ref_base64_list, + aspect_ratio=ar if ar != "auto" else None, + version=model["model"], + test_time_scaling=tts if tts and tts > 1 else None, + postprocessing=_build_postprocessing(upscale, remove_background), + ), + ) + return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response))) + + +class ReveExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ReveImageCreateNode, + ReveImageEditNode, + ReveImageRemixNode, + ] + + +async def comfy_entrypoint() -> ReveExtension: + return ReveExtension() diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index e60e7a6d6..2b829b8db 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -10,11 +10,10 @@ import folder_paths as comfy_paths import os import logging import math -from typing import Optional from io import BytesIO from typing_extensions import override from PIL import Image -from comfy_api_nodes.apis.rodin_api import ( +from comfy_api_nodes.apis.rodin import ( Rodin3DGenerateRequest, Rodin3DGenerateResponse, Rodin3DCheckStatusRequest, @@ -28,8 +27,9 @@ from comfy_api_nodes.util import ( poll_op, ApiEndpoint, download_url_to_bytesio, + download_url_to_file_3d, ) -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import ComfyExtension, IO, Types COMMON_PARAMETERS = [ @@ -177,7 +177,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" -def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: +def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None: if not response.jobs: return None completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) @@ -207,17 +207,25 @@ async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3D ) -async def download_files(url_list, task_uuid: str): +async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.File3D | None]: result_folder_name = f"Rodin3D_{task_uuid}" save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None + file_3d = None + for i in url_list.list: file_path = os.path.join(save_path, i.name) - if file_path.endswith(".glb"): + if i.name.lower().endswith(".glb"): model_file_path = os.path.join(result_folder_name, i.name) - await 
download_url_to_bytesio(i.url, file_path) - return model_file_path + file_3d = await download_url_to_file_3d(i.url, "glb") + # Save to disk for backward compatibility + with open(file_path, "wb") as f: + f.write(file_3d.get_bytes()) + else: + await download_url_to_bytesio(i.url, file_path) + + return model_file_path, file_3d class Rodin3D_Regular(IO.ComfyNode): @@ -234,13 +242,19 @@ class Rodin3D_Regular(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -268,9 +282,9 @@ class Rodin3D_Regular(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Detail(IO.ComfyNode): @@ -287,13 +301,19 @@ class Rodin3D_Detail(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -321,9 +341,9 @@ class Rodin3D_Detail(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Smooth(IO.ComfyNode): @@ -340,13 +360,19 @@ class Rodin3D_Smooth(IO.ComfyNode): IO.Image.Input("Images"), *COMMON_PARAMETERS, ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -373,9 +399,9 @@ class Rodin3D_Smooth(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Sketch(IO.ComfyNode): @@ -399,13 +425,19 @@ class Rodin3D_Sketch(IO.ComfyNode): optional=True, ), ], - outputs=[IO.String.Output(display_name="3D Model Path")], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), + ], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + 
expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -429,9 +461,9 @@ class Rodin3D_Sketch(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3D_Gen2(IO.ComfyNode): @@ -461,15 +493,21 @@ class Rodin3D_Gen2(IO.ComfyNode): default="500K-Triangle", optional=True, ), - IO.Boolean.Input("TAPose", default=False), + IO.Boolean.Input("TAPose", default=False, advanced=True), + ], + outputs=[ + IO.String.Output(display_name="3D Model Path"), # for backward compatibility only + IO.File3DGLB.Output(display_name="GLB"), ], - outputs=[IO.String.Output(display_name="3D Model Path")], hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -499,9 +537,9 @@ class Rodin3D_Gen2(IO.ComfyNode): ) await poll_for_task_status(subscription_key, cls) download_list = await get_rodin_download_list(task_uuid, cls) - model = await download_files(download_list, task_uuid) + model_path, file_3d = await download_files(download_list, task_uuid) - return IO.NodeOutput(model) + return IO.NodeOutput(model_path, file_3d) class Rodin3DExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 3c55039c9..573170ba2 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -16,7 +16,7 @@ from enum import Enum from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis import ( +from comfy_api_nodes.apis.runway import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, @@ -184,6 +184,10 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -274,6 +278,10 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -372,6 +380,10 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration"]), + expr="""{"type":"usd","usd": 0.0715 * widgets.duration}""", + ), ) @classmethod @@ -457,6 +469,9 @@ class RunwayTextToImageNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.11}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index 92b225d40..afc18bb25 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -89,6 +89,24 @@ class OpenAIVideoSora2(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "size", "duration"]), + expr=""" + ( + $m := widgets.model; + $size := widgets.size; + $dur := widgets.duration; + $isPro := $contains($m, "sora-2-pro"); + 
$isSora2 := $contains($m, "sora-2"); + $isProSize := ($size = "1024x1792" or $size = "1792x1024"); + $perSec := + $isPro ? ($isProSize ? 0.5 : 0.3) : + $isSora2 ? 0.1 : + ($isProSize ? 0.5 : 0.1); + {"type":"usd","usd": $round($perSec * $dur, 2)} + ) + """, + ), ) @classmethod @@ -131,7 +149,6 @@ class OpenAIVideoSora2(IO.ComfyNode): response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) return IO.NodeOutput( diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index bb7ceed78..9ef13c83b 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -3,7 +3,7 @@ from typing import Optional from typing_extensions import override from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.stability_api import ( +from comfy_api_nodes.apis.stability import ( StabilityUpscaleConservativeRequest, StabilityUpscaleCreativeRequest, StabilityAsyncResponse, @@ -86,6 +86,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): "style_preset", options=get_stability_style_presets(), tooltip="Optional desired style of generated image.", + advanced=True, ), IO.Int.Input( "seed", @@ -107,6 +108,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.", force_input=True, optional=True, + advanced=True, ), IO.Float.Input( "image_denoise", @@ -127,6 +129,9 @@ class StabilityStableImageUltraNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.08}""", + ), ) @classmethod @@ -215,6 +220,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): "style_preset", options=get_stability_style_presets(), tooltip="Optional desired style of generated image.", + advanced=True, ), IO.Float.Input( "cfg_scale", @@ -244,6 +250,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", force_input=True, optional=True, + advanced=True, ), IO.Float.Input( "image_denoise", @@ -264,6 +271,16 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $contains(widgets.model,"large") + ? {"type":"usd","usd":0.065} + : {"type":"usd","usd":0.035} + ) + """, + ), ) @classmethod @@ -371,6 +388,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.", force_input=True, optional=True, + advanced=True, ), ], outputs=[ @@ -382,6 +400,9 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -458,6 +479,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): "style_preset", options=get_stability_style_presets(), tooltip="Optional desired style of generated image.", + advanced=True, ), IO.Int.Input( "seed", @@ -475,6 +497,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): tooltip="Keywords of what you do not wish to see in the output image. 
This is an advanced feature.", force_input=True, optional=True, + advanced=True, ), ], outputs=[ @@ -486,6 +509,9 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -566,6 +592,9 @@ class StabilityUpscaleFastNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.01}""", + ), ) @classmethod @@ -602,6 +631,7 @@ class StabilityTextToAudio(IO.ComfyNode): node_id="StabilityTextToAudio", display_name="Stability AI Text To Audio", category="api node/audio/Stability AI", + essentials_category="Audio", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( @@ -637,6 +667,7 @@ class StabilityTextToAudio(IO.ComfyNode): step=1, tooltip="Controls the number of sampling steps.", optional=True, + advanced=True, ), ], outputs=[ @@ -648,6 +679,9 @@ class StabilityTextToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -711,6 +745,7 @@ class StabilityAudioToAudio(IO.ComfyNode): step=1, tooltip="Controls the number of sampling steps.", optional=True, + advanced=True, ), IO.Float.Input( "strength", @@ -732,6 +767,9 @@ class StabilityAudioToAudio(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod @@ -801,6 +839,7 @@ class StabilityAudioInpaint(IO.ComfyNode): step=1, tooltip="Controls the number of sampling steps.", optional=True, + advanced=True, ), IO.Int.Input( "mask_start", @@ -809,6 +848,7 @@ class StabilityAudioInpaint(IO.ComfyNode): max=190, step=1, optional=True, + advanced=True, ), IO.Int.Input( "mask_end", @@ -817,6 +857,7 @@ class StabilityAudioInpaint(IO.ComfyNode): max=190, step=1, optional=True, + advanced=True, ), ], outputs=[ @@ -828,6 +869,9 @@ class StabilityAudioInpaint(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.2}""", + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index f522756e5..6b61bd4b2 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -2,11 +2,27 @@ import builtins from io import BytesIO import aiohttp -import torch from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input -from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.apis.topaz import ( + CreateVideoRequest, + CreateVideoRequestSource, + CreateVideoResponse, + ImageAsyncTaskResponse, + ImageDownloadResponse, + ImageEnhanceRequest, + ImageStatusResponse, + OutputInformationVideo, + Resolution, + VideoAcceptResponse, + VideoCompleteUploadRequest, + VideoCompleteUploadRequestPart, + VideoCompleteUploadResponse, + VideoEnhancementFilter, + VideoFrameInterpolationFilter, + VideoStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, @@ -23,10 +39,6 @@ UPSCALER_MODELS_MAP = { "Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Creative": "slc-1", } -UPSCALER_VALUES_MAP = { - "FullHD (1080p)": 1920, - "4K (2160p)": 3840, -} class TopazImageEnhance(IO.ComfyNode): @@ -51,12 +63,14 @@ class TopazImageEnhance(IO.ComfyNode): "subject_detection", options=["All", "Foreground", "Background"], optional=True, + advanced=True, ), IO.Boolean.Input( "face_enhancement", default=True, optional=True, 
tooltip="Enhance faces (if present) during processing.", + advanced=True, ), IO.Float.Input( "face_enhancement_creativity", @@ -67,6 +81,7 @@ class TopazImageEnhance(IO.ComfyNode): display_mode=IO.NumberDisplay.number, optional=True, tooltip="Set the creativity level for face enhancement.", + advanced=True, ), IO.Float.Input( "face_enhancement_strength", @@ -77,6 +92,7 @@ class TopazImageEnhance(IO.ComfyNode): display_mode=IO.NumberDisplay.number, optional=True, tooltip="Controls how sharp enhanced faces are relative to the background.", + advanced=True, ), IO.Boolean.Input( "crop_to_fill", @@ -84,6 +100,7 @@ class TopazImageEnhance(IO.ComfyNode): optional=True, tooltip="By default, the image is letterboxed when the output aspect ratio differs. " "Enable to crop the image to fill the output dimensions.", + advanced=True, ), IO.Int.Input( "output_width", @@ -94,6 +111,7 @@ class TopazImageEnhance(IO.ComfyNode): display_mode=IO.NumberDisplay.number, optional=True, tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).", + advanced=True, ), IO.Int.Input( "output_height", @@ -104,6 +122,7 @@ class TopazImageEnhance(IO.ComfyNode): display_mode=IO.NumberDisplay.number, optional=True, tooltip="Zero value means to output in the same height as original or output width.", + advanced=True, ), IO.Int.Input( "creativity", @@ -119,12 +138,14 @@ class TopazImageEnhance(IO.ComfyNode): default=True, optional=True, tooltip="Preserve subjects' facial identity.", + advanced=True, ), IO.Boolean.Input( "color_preservation", default=True, optional=True, tooltip="Preserve the original colors.", + advanced=True, ), ], outputs=[ @@ -142,7 +163,7 @@ class TopazImageEnhance(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str = "", subject_detection: str = "All", face_enhancement: bool = True, @@ -157,12 +178,14 @@ class TopazImageEnhance(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) != 1: raise ValueError("Only one input image is supported.") - download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + download_url = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", total_pixels=4096 * 4096 + ) initial_response = await sync_op( cls, ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), - response_model=topaz_api.ImageAsyncTaskResponse, - data=topaz_api.ImageEnhanceRequest( + response_model=ImageAsyncTaskResponse, + data=ImageEnhanceRequest( model=model, prompt=prompt, subject_detection=subject_detection, @@ -184,19 +207,18 @@ class TopazImageEnhance(IO.ComfyNode): await poll_op( cls, poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), - response_model=topaz_api.ImageStatusResponse, + response_model=ImageStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: x.credits * 0.08, poll_interval=8.0, - max_poll_attempts=160, estimated_duration=60, ) results = await sync_op( cls, ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), - response_model=topaz_api.ImageDownloadResponse, + response_model=ImageDownloadResponse, monitor_progress=False, ) return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) @@ -214,16 +236,17 @@ class TopazVideoEnhance(IO.ComfyNode): IO.Video.Input("video"), IO.Boolean.Input("upscaler_enabled", 
default=True), IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), - IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input( "upscaler_creativity", options=["low", "middle", "high"], default="low", tooltip="Creativity level (applies only to Starlight (Astra) Creative).", optional=True, + advanced=True, ), IO.Boolean.Input("interpolation_enabled", default=False, optional=True), - IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True), + IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True, advanced=True), IO.Int.Input( "interpolation_slowmo", default=1, @@ -233,6 +256,7 @@ class TopazVideoEnhance(IO.ComfyNode): tooltip="Slow-motion factor applied to the input video. " "For example, 2 makes the output twice as slow and doubles the duration.", optional=True, + advanced=True, ), IO.Int.Input( "interpolation_frame_rate", @@ -248,6 +272,7 @@ class TopazVideoEnhance(IO.ComfyNode): default=False, tooltip="Analyze the input for duplicate frames and remove them.", optional=True, + advanced=True, ), IO.Float.Input( "interpolation_duplicate_threshold", @@ -258,6 +283,7 @@ class TopazVideoEnhance(IO.ComfyNode): display_mode=IO.NumberDisplay.number, tooltip="Detection sensitivity for duplicate frames.", optional=True, + advanced=True, ), IO.Combo.Input( "dynamic_compression_level", @@ -265,6 +291,7 @@ class TopazVideoEnhance(IO.ComfyNode): default="Low", tooltip="CQP level.", optional=True, + advanced=True, ), ], outputs=[ @@ -306,10 +333,35 @@ class TopazVideoEnhance(IO.ComfyNode): target_frame_rate = src_frame_rate filters = [] if upscaler_enabled: - target_width = UPSCALER_VALUES_MAP[upscaler_resolution] - target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + if "1080p" in upscaler_resolution: + target_pixel_p = 1080 + max_long_side = 1920 + else: + target_pixel_p = 2160 + max_long_side = 3840 + ar = src_width / src_height + if src_width >= src_height: + # Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width + target_height = target_pixel_p + target_width = int(target_height * ar) + # Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs) + if target_width > max_long_side: + target_width = max_long_side + target_height = int(target_width / ar) + else: + # Portrait; Attempt to set width to target (e.g., 2160), calculate height + target_width = target_pixel_p + target_height = int(target_width / ar) + # Check if height exceeds standard bounds + if target_height > max_long_side: + target_height = max_long_side + target_width = int(target_height * ar) + if target_width % 2 != 0: + target_width += 1 + if target_height % 2 != 0: + target_height += 1 filters.append( - topaz_api.VideoEnhancementFilter( + VideoEnhancementFilter( model=UPSCALER_MODELS_MAP[upscaler_model], creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), @@ -318,7 +370,7 @@ class TopazVideoEnhance(IO.ComfyNode): if interpolation_enabled: target_frame_rate = interpolation_frame_rate filters.append( - topaz_api.VideoFrameInterpolationFilter( + VideoFrameInterpolationFilter( model=interpolation_model, slowmo=interpolation_slowmo, fps=interpolation_frame_rate, @@ -329,19 +381,19 @@ class TopazVideoEnhance(IO.ComfyNode): initial_res = await sync_op( cls, 
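+        # create the Topaz video task first; the accept/upload/complete-upload/status calls below drive the rest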
ApiEndpoint(path="/proxy/topaz/video/", method="POST"), - response_model=topaz_api.CreateVideoResponse, - data=topaz_api.CreateVideoRequest( - source=topaz_api.CreateCreateVideoRequestSource( + response_model=CreateVideoResponse, + data=CreateVideoRequest( + source=CreateVideoRequestSource( container="mp4", size=get_fs_object_size(src_video_stream), duration=int(duration_sec), frameCount=video.get_frame_count(), frameRate=src_frame_rate, - resolution=topaz_api.Resolution(width=src_width, height=src_height), + resolution=Resolution(width=src_width, height=src_height), ), filters=filters, - output=topaz_api.OutputInformationVideo( - resolution=topaz_api.Resolution(width=target_width, height=target_height), + output=OutputInformationVideo( + resolution=Resolution(width=target_width, height=target_height), frameRate=target_frame_rate, audioCodec="AAC", audioTransfer="Copy", @@ -357,7 +409,7 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/accept", method="PATCH", ), - response_model=topaz_api.VideoAcceptResponse, + response_model=VideoAcceptResponse, wait_label="Preparing upload", final_label_on_success="Upload started", ) @@ -380,10 +432,10 @@ class TopazVideoEnhance(IO.ComfyNode): path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", method="PATCH", ), - response_model=topaz_api.VideoCompleteUploadResponse, - data=topaz_api.VideoCompleteUploadRequest( + response_model=VideoCompleteUploadResponse, + data=VideoCompleteUploadRequest( uploadResults=[ - topaz_api.VideoCompleteUploadRequestPart( + VideoCompleteUploadRequestPart( partNum=1, eTag=upload_etag, ), @@ -395,7 +447,7 @@ class TopazVideoEnhance(IO.ComfyNode): final_response = await poll_op( cls, ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), - response_model=topaz_api.VideoStatusResponse, + response_model=VideoStatusResponse, status_extractor=lambda x: x.status, progress_extractor=lambda x: getattr(x, "progress", 0), price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index bd3c24fb3..9f4298dce 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,11 +1,7 @@ -import os -from typing import Optional - -import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension -from comfy_api_nodes.apis.tripo_api import ( +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.tripo import ( TripoAnimateRetargetRequest, TripoAnimateRigRequest, TripoConvertModelRequest, @@ -26,12 +22,11 @@ from comfy_api_nodes.apis.tripo_api import ( ) from comfy_api_nodes.util import ( ApiEndpoint, - download_url_as_bytesio, + download_url_to_file_3d, poll_op, sync_op, upload_images_to_comfyapi, ) -from folder_paths import get_output_directory def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -45,7 +40,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( node_cls: type[IO.ComfyNode], response: TripoTaskResponse, - average_duration: Optional[int] = None, + average_duration: int | None = None, ) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: @@ -69,12 +64,8 @@ async def poll_until_finished( ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - 
bytesio = await download_url_as_bytesio(url) - # Save the downloaded model file - model_file = f"tripo_model_{task_id}.glb" - with open(os.path.join(get_output_directory(), model_file), "wb") as f: - f.write(bytesio.getvalue()) - return IO.NodeOutput(model_file, task_id) + file_glb = await download_url_to_file_3d(url, "glb", task_id=task_id) + return IO.NodeOutput(f"{task_id}.glb", task_id, file_glb) raise RuntimeError(f"Failed to generate mesh: {response_poll}") @@ -98,17 +89,18 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("image_seed", default=42, optional=True), - IO.Int.Input("model_seed", default=42, optional=True), - IO.Int.Input("texture_seed", default=42, optional=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), - IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True), - IO.Boolean.Input("quad", default=False, optional=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("image_seed", default=42, optional=True, advanced=True), + IO.Int.Input("model_seed", default=42, optional=True, advanced=True), + IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True), + IO.Boolean.Input("quad", default=False, optional=True, advanced=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -117,24 +109,56 @@ class TripoTextToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 20 : ($withTexture ? 20 : 10); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 
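+                    /* detailed geometry adds 20 credits; credits are converted to USD at $0.01 each below */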
20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod async def execute( cls, prompt: str, - negative_prompt: Optional[str] = None, + negative_prompt: str | None = None, model_version=None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - image_seed: Optional[int] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + image_seed: int | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: @@ -155,7 +179,7 @@ class TripoTextToModelNode(IO.ComfyNode): model_seed=model_seed, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, geometry_quality=geometry_quality, auto_size=True, quad=quad, @@ -186,22 +210,23 @@ class TripoImageToModelNode(IO.ComfyNode): IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Combo.Input( - "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True + "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True ), - IO.Int.Input("texture_seed", default=42, optional=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True ), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), - IO.Boolean.Input("quad", default=False, optional=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), + IO.Boolean.Input("quad", default=False, optional=True, advanced=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -210,24 +235,56 @@ class TripoImageToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "style", + "texture", + "pbr", + "quad", + 
"texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $style := widgets.style; + $hasStyle := ($style != "" and $style != "none"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + ($hasStyle ? 5 : 0) + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod async def execute( cls, - image: torch.Tensor, - model_version: Optional[str] = None, - style: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, + image: Input.Image, + model_version: str | None = None, + style: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, orientation=None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: @@ -255,7 +312,7 @@ class TripoImageToModelNode(IO.ComfyNode): texture_alignment=texture_alignment, texture_seed=texture_seed, texture_quality=texture_quality, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, auto_size=True, quad=quad, ), @@ -290,22 +347,24 @@ class TripoMultiviewToModelNode(IO.ComfyNode): options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, + advanced=True, ), IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("model_seed", default=42, optional=True), - IO.Int.Input("texture_seed", default=42, optional=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("model_seed", default=42, optional=True, advanced=True), + IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True ), - IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), - IO.Boolean.Input("quad", default=False, optional=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), + IO.Boolean.Input("quad", default=False, optional=True, advanced=True), + IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only 
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -314,26 +373,54 @@ class TripoMultiviewToModelNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model_version", + "texture", + "pbr", + "quad", + "texture_quality", + "geometry_quality", + ], + ), + expr=""" + ( + $isV14 := $contains(widgets.model_version,"v1.4"); + $withTexture := widgets.texture or widgets.pbr; + $isHdTexture := (widgets.texture_quality = "detailed"); + $isDetailedGeometry := (widgets.geometry_quality = "detailed"); + $baseCredits := + $isV14 ? 30 : ($withTexture ? 30 : 20); + $credits := + $baseCredits + + (widgets.quad ? 5 : 0) + + ($isHdTexture ? 10 : 0) + + ($isDetailedGeometry ? 20 : 0); + {"type":"usd","usd": $round($credits * 0.01, 2)} + ) + """, + ), ) @classmethod async def execute( cls, - image: torch.Tensor, - image_left: Optional[torch.Tensor] = None, - image_back: Optional[torch.Tensor] = None, - image_right: Optional[torch.Tensor] = None, - model_version: Optional[str] = None, - orientation: Optional[str] = None, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - model_seed: Optional[int] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - geometry_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, - face_limit: Optional[int] = None, - quad: Optional[bool] = None, + image: Input.Image, + image_left: Input.Image | None = None, + image_back: Input.Image | None = None, + image_right: Input.Image | None = None, + model_version: str | None = None, + orientation: str | None = None, + texture: bool | None = None, + pbr: bool | None = None, + model_seed: int | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + geometry_quality: str | None = None, + texture_alignment: str | None = None, + face_limit: int | None = None, + quad: bool | None = None, ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") @@ -369,7 +456,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): texture_quality=texture_quality, geometry_quality=geometry_quality, texture_alignment=texture_alignment, - face_limit=face_limit, + face_limit=face_limit if face_limit != -1 else None, quad=quad, ), ) @@ -388,15 +475,16 @@ class TripoTextureNode(IO.ComfyNode): IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), - IO.Int.Input("texture_seed", default=42, optional=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -405,17 
+493,26 @@ class TripoTextureNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["texture_quality"]), + expr=""" + ( + $tq := widgets.texture_quality; + {"type":"usd","usd": ($contains($tq,"detailed") ? 0.2 : 0.1)} + ) + """, + ), ) @classmethod async def execute( cls, model_task_id, - texture: Optional[bool] = None, - pbr: Optional[bool] = None, - texture_seed: Optional[int] = None, - texture_quality: Optional[str] = None, - texture_alignment: Optional[str] = None, + texture: bool | None = None, + pbr: bool | None = None, + texture_seed: int | None = None, + texture_quality: str | None = None, + texture_alignment: str | None = None, ) -> IO.NodeOutput: response = await sync_op( cls, @@ -446,8 +543,9 @@ class TripoRefineNode(IO.ComfyNode): IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -456,6 +554,9 @@ class TripoRefineNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.3}""", + ), ) @classmethod @@ -479,8 +580,9 @@ class TripoRigNode(IO.ComfyNode): category="api node/3d/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -489,6 +591,9 @@ class TripoRigNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.25}""", + ), ) @classmethod @@ -535,8 +640,9 @@ class TripoRetargetNode(IO.ComfyNode): ), ], outputs=[ - IO.String.Output(display_name="model_file"), + IO.String.Output(display_name="model_file"), # for backward compatibility only IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + IO.File3DGLB.Output(display_name="GLB"), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -545,6 +651,9 @@ class TripoRetargetNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.1}""", + ), ) @classmethod @@ -574,13 +683,14 @@ class TripoConversionNode(IO.ComfyNode): inputs=[ IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), - IO.Boolean.Input("quad", default=False, optional=True), + IO.Boolean.Input("quad", default=False, optional=True, advanced=True), IO.Int.Input( "face_limit", default=-1, min=-1, max=2000000, optional=True, + advanced=True, ), IO.Int.Input( "texture_size", @@ -588,47 +698,53 @@ class TripoConversionNode(IO.ComfyNode): min=128, max=4096, optional=True, + advanced=True, ), IO.Combo.Input( "texture_format", options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], default="JPEG", optional=True, + advanced=True, ), - IO.Boolean.Input("force_symmetry", default=False, optional=True), - IO.Boolean.Input("flatten_bottom", default=False, optional=True), + IO.Boolean.Input("force_symmetry", default=False, 
optional=True, advanced=True), + IO.Boolean.Input("flatten_bottom", default=False, optional=True, advanced=True), IO.Float.Input( "flatten_bottom_threshold", default=0.0, min=0.0, max=1.0, optional=True, + advanced=True, ), - IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True), + IO.Boolean.Input("pivot_to_center_bottom", default=False, optional=True, advanced=True), IO.Float.Input( "scale_factor", default=1.0, min=0.0, optional=True, + advanced=True, ), - IO.Boolean.Input("with_animation", default=False, optional=True), - IO.Boolean.Input("pack_uv", default=False, optional=True), - IO.Boolean.Input("bake", default=False, optional=True), - IO.String.Input("part_names", default="", optional=True), # comma-separated list + IO.Boolean.Input("with_animation", default=False, optional=True, advanced=True), + IO.Boolean.Input("pack_uv", default=False, optional=True, advanced=True), + IO.Boolean.Input("bake", default=False, optional=True, advanced=True), + IO.String.Input("part_names", default="", optional=True, advanced=True), # comma-separated list IO.Combo.Input( "fbx_preset", options=["blender", "mixamo", "3dsmax"], default="blender", optional=True, + advanced=True, ), - IO.Boolean.Input("export_vertex_colors", default=False, optional=True), + IO.Boolean.Input("export_vertex_colors", default=False, optional=True, advanced=True), IO.Combo.Input( "export_orientation", options=["align_image", "default"], default="default", optional=True, + advanced=True, ), - IO.Boolean.Input("animate_in_place", default=False, optional=True), + IO.Boolean.Input("animate_in_place", default=False, optional=True, advanced=True), ], outputs=[], hidden=[ @@ -638,6 +754,60 @@ class TripoConversionNode(IO.ComfyNode): ], is_api_node=True, is_output_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "quad", + "face_limit", + "texture_size", + "texture_format", + "force_symmetry", + "flatten_bottom", + "flatten_bottom_threshold", + "pivot_to_center_bottom", + "scale_factor", + "with_animation", + "pack_uv", + "bake", + "part_names", + "fbx_preset", + "export_vertex_colors", + "export_orientation", + "animate_in_place", + ], + ), + expr=""" + ( + $face := (widgets.face_limit != null) ? widgets.face_limit : -1; + $texSize := (widgets.texture_size != null) ? widgets.texture_size : 4096; + $flatThresh := (widgets.flatten_bottom_threshold != null) ? widgets.flatten_bottom_threshold : 0; + $scale := (widgets.scale_factor != null) ? widgets.scale_factor : 1; + $texFmt := (widgets.texture_format != "" ? widgets.texture_format : "jpeg"); + $part := widgets.part_names; + $fbx := (widgets.fbx_preset != "" ? widgets.fbx_preset : "blender"); + $orient := (widgets.export_orientation != "" ? widgets.export_orientation : "default"); + $advanced := + widgets.quad or + widgets.force_symmetry or + widgets.flatten_bottom or + widgets.pivot_to_center_bottom or + widgets.with_animation or + widgets.pack_uv or + widgets.bake or + widgets.export_vertex_colors or + widgets.animate_in_place or + ($face != -1) or + ($texSize != 4096) or + ($flatThresh != 0) or + ($scale != 1) or + ($texFmt != "jpeg") or + ($part != "") or + ($fbx != "blender") or + ($orient != "default"); + {"type":"usd","usd": ($advanced ? 
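+                /* any non-default option selects the higher conversion price */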
0.1 : 0.05)} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index e165b8380..13fc1cc36 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -4,7 +4,7 @@ from io import BytesIO from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input, InputImpl -from comfy_api_nodes.apis.veo_api import ( +from comfy_api_nodes.apis.veo import ( VeoGenVidPollRequest, VeoGenVidPollResponse, VeoGenVidRequest, @@ -81,6 +81,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): default=True, tooltip="Whether to enhance the prompt with AI assistance", optional=True, + advanced=True, ), IO.Combo.Input( "person_generation", @@ -88,6 +89,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): default="ALLOW", tooltip="Whether to allow generating people in the video", optional=True, + advanced=True, ), IO.Int.Input( "seed", @@ -122,6 +124,10 @@ class VeoVideoGenerationNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration_seconds"]), + expr="""{"type":"usd","usd": 0.5 * widgets.duration_seconds}""", + ), ) @classmethod @@ -168,6 +174,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Only add generateAudio for Veo 3 models if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio + # force "enhance_prompt" to True for Veo3 models + parameters["enhancePrompt"] = True initial_response = await sync_op( cls, @@ -291,8 +299,9 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Boolean.Input( "enhance_prompt", default=True, - tooltip="Whether to enhance the prompt with AI assistance", + tooltip="This parameter is deprecated and ignored.", optional=True, + advanced=True, ), IO.Combo.Input( "person_generation", @@ -300,6 +309,7 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): default="ALLOW", tooltip="Whether to allow generating people in the video", optional=True, + advanced=True, ), IO.Int.Input( "seed", @@ -345,6 +355,20 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + expr=""" + ( + $m := widgets.model; + $a := widgets.generate_audio; + ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) + ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} + : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) + ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} + : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + ) + """, + ), ) @@ -418,6 +442,30 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + expr=""" + ( + $prices := { + "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, + "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } + }; + $m := widgets.model; + $ga := (widgets.generate_audio = "true"); + $seconds := widgets.duration; + $modelKey := + $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : + $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : + ""; + $audioKey := $ga ? "audio" : "no_audio"; + $modelPrices := $lookup($prices, $modelKey); + $pps := $lookup($modelPrices, $audioKey); + ($pps != null) + ? 
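+                    /* known model/audio combination: charge the per-second rate times duration */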
{"type":"usd","usd": $pps * $seconds} + : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + ) + """, + ), ) @classmethod diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 7a679f0d9..f04407eb5 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,22 +1,30 @@ -import logging -from enum import Enum -from typing import Literal, Optional, TypeVar - -import torch -from pydantic import BaseModel, Field from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.vidu import ( + FrameSetting, + SubjectReference, + TaskCreationRequest, + TaskCreationResponse, + TaskExtendCreationRequest, + TaskMultiFrameCreationRequest, + TaskResult, + TaskStatusResponse, +) from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_video_output, get_number_of_images, poll_op, sync_op, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, validate_image_dimensions, validate_images_aspect_ratio_closeness, + validate_string, + validate_video_duration, ) VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" @@ -25,98 +33,35 @@ VIDU_REFERENCE_VIDEO = "/proxy/vidu/reference2video" VIDU_START_END_VIDEO = "/proxy/vidu/start-end2video" VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" -R = TypeVar("R") - - -class VideoModelName(str, Enum): - vidu_q1 = "viduq1" - - -class AspectRatio(str, Enum): - r_16_9 = "16:9" - r_9_16 = "9:16" - r_1_1 = "1:1" - - -class Resolution(str, Enum): - r_1080p = "1080p" - - -class MovementAmplitude(str, Enum): - auto = "auto" - small = "small" - medium = "medium" - large = "large" - - -class TaskCreationRequest(BaseModel): - model: VideoModelName = VideoModelName.vidu_q1 - prompt: Optional[str] = Field(None, max_length=1500) - duration: Optional[Literal[5]] = 5 - seed: Optional[int] = Field(0, ge=0, le=2147483647) - aspect_ratio: Optional[AspectRatio] = AspectRatio.r_16_9 - resolution: Optional[Resolution] = Resolution.r_1080p - movement_amplitude: Optional[MovementAmplitude] = MovementAmplitude.auto - images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") - - -class TaskCreationResponse(BaseModel): - task_id: str = Field(...) - state: str = Field(...) - created_at: str = Field(...) - code: Optional[int] = Field(None, description="Error code") - - -class TaskResult(BaseModel): - id: str = Field(..., description="Creation id") - url: str = Field(..., description="The URL of the generated results, valid for one hour") - cover_url: str = Field(..., description="The cover URL of the generated results, valid for one hour") - - -class TaskStatusResponse(BaseModel): - state: str = Field(...) - err_code: Optional[str] = Field(None) - creations: list[TaskResult] = Field(..., description="Generated results") - - -def get_video_url_from_response(response) -> Optional[str]: - if response.creations: - return response.creations[0].url - return None - - -def get_video_from_response(response) -> TaskResult: - if not response.creations: - error_msg = f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" - logging.info(error_msg) - raise RuntimeError(error_msg) - logging.info("Vidu task %s succeeded. 
Video URL: %s", response.creations[0].id, response.creations[0].url) - return response.creations[0] - async def execute_task( cls: type[IO.ComfyNode], vidu_endpoint: str, - payload: TaskCreationRequest, - estimated_duration: int, -) -> R: - response = await sync_op( + payload: TaskCreationRequest | TaskExtendCreationRequest | TaskMultiFrameCreationRequest, + max_poll_attempts: int = 320, +) -> list[TaskResult]: + task_creation_response = await sync_op( cls, endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), response_model=TaskCreationResponse, data=payload, ) - if response.state == "failed": - error_msg = f"Vidu request failed. Code: {response.code}" - logging.error(error_msg) - raise RuntimeError(error_msg) - return await poll_op( + if task_creation_response.state == "failed": + raise RuntimeError(f"Vidu request failed. Code: {task_creation_response.code}") + response = await poll_op( cls, - ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % task_creation_response.task_id), response_model=TaskStatusResponse, status_extractor=lambda r: r.state, - estimated_duration=estimated_duration, + progress_extractor=lambda r: r.progress, + price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None, + max_poll_attempts=max_poll_attempts, ) + if not response.creations: + raise RuntimeError( + f"Vidu request does not contain results. State: {response.state}, Error Code: {response.err_code}" + ) + return response.creations class ViduTextToVideoNode(IO.ComfyNode): @@ -127,14 +72,9 @@ class ViduTextToVideoNode(IO.ComfyNode): node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", category="api node/video/Vidu", - description="Generate video from text prompt", + description="Generate video from a text prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.String.Input( "prompt", multiline=True, @@ -163,24 +103,23 @@ class ViduTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, + advanced=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, + advanced=True, ), ], outputs=[ @@ -192,6 +131,9 @@ class ViduTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod @@ -208,7 +150,7 @@ class ViduTextToVideoNode(IO.ComfyNode): if not prompt: raise ValueError("The prompt field is required and cannot be empty.") payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -216,8 +158,8 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, 
VIDU_TEXT_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduImageToVideoNode(IO.ComfyNode): @@ -230,12 +172,7 @@ class ViduImageToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate video from image and optional prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "image", tooltip="An image to be used as the start frame of the generated video", @@ -270,17 +207,17 @@ class ViduImageToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=Resolution, - default=Resolution.r_1080p, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, + advanced=True, ), IO.Combo.Input( "movement_amplitude", - options=MovementAmplitude, - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, + advanced=True, ), ], outputs=[ @@ -292,13 +229,16 @@ class ViduImageToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, duration: int, seed: int, @@ -309,7 +249,7 @@ class ViduImageToVideoNode(IO.ComfyNode): raise ValueError("Only one input image is allowed.") validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -322,8 +262,8 @@ class ViduImageToVideoNode(IO.ComfyNode): max_images=1, mime_type="image/png", ) - results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduReferenceVideoNode(IO.ComfyNode): @@ -334,14 +274,9 @@ class ViduReferenceVideoNode(IO.ComfyNode): node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", category="api node/video/Vidu", - description="Generate video from multiple images and prompt", + description="Generate video from multiple images and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=VideoModelName, - default=VideoModelName.vidu_q1, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "images", tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).", @@ -374,24 +309,23 @@ class ViduReferenceVideoNode(IO.ComfyNode): ), IO.Combo.Input( "aspect_ratio", - options=AspectRatio, - default=AspectRatio.r_16_9, + options=["16:9", "9:16", "1:1"], tooltip="The aspect ratio of the output video", optional=True, ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, + advanced=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, + advanced=True, 
), ], outputs=[ @@ -403,13 +337,16 @@ class ViduReferenceVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - images: torch.Tensor, + images: Input.Image, prompt: str, duration: int, seed: int, @@ -426,7 +363,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -440,8 +377,8 @@ class ViduReferenceVideoNode(IO.ComfyNode): max_images=7, mime_type="image/png", ) - results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduStartEndToVideoNode(IO.ComfyNode): @@ -454,12 +391,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): category="api node/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ - IO.Combo.Input( - "model", - options=[model.value for model in VideoModelName], - default=VideoModelName.vidu_q1.value, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), IO.Image.Input( "first_frame", tooltip="Start frame", @@ -497,17 +429,17 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ), IO.Combo.Input( "resolution", - options=[model.value for model in Resolution], - default=Resolution.r_1080p.value, + options=["1080p"], tooltip="Supported values may vary by model & duration", optional=True, + advanced=True, ), IO.Combo.Input( "movement_amplitude", - options=[model.value for model in MovementAmplitude], - default=MovementAmplitude.auto.value, + options=["auto", "small", "medium", "large"], tooltip="The movement amplitude of objects in the frame", optional=True, + advanced=True, ), ], outputs=[ @@ -519,14 +451,17 @@ class ViduStartEndToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.4}""", + ), ) @classmethod async def execute( cls, model: str, - first_frame: torch.Tensor, - end_frame: torch.Tensor, + first_frame: Input.Image, + end_frame: Input.Image, prompt: str, duration: int, seed: int, @@ -535,7 +470,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( - model_name=model, + model=model, prompt=prompt, duration=duration, seed=seed, @@ -546,8 +481,1227 @@ class ViduStartEndToVideoNode(IO.ComfyNode): (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) - return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2TextToVideoNode", + display_name="Vidu2 Text-to-Video Generation", + category="api node/video/Vidu", + 
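+            # viduq2 counterpart of ViduTextToVideoNode above; adds a 720p resolution option and a background_music toggle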
description="Generate video from a text prompt", + inputs=[ + IO.Combo.Input("model", options=["viduq2"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "3:4", "4:3", "1:1"]), + IO.Combo.Input("resolution", options=["720p", "1080p"], advanced=True), + IO.Boolean.Input( + "background_music", + default=False, + tooltip="Whether to add background music to the generated video.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $is1080 := widgets.resolution = "1080p"; + $base := $is1080 ? 0.1 : 0.075; + $perSec := $is1080 ? 0.05 : 0.025; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + background_music: bool, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + bgm=background_music, + ), + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2ImageToVideoNode", + display_name="Vidu2 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=10, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + advanced=True, + ), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 
0.02 : 0.01;
+                            {"type":"usd","usd": $base + $perSec * ($d - 1)}
+                        )
+                    : $contains($m, "pro")
+                        ? (
+                            $base := $is1080 ? 0.275 : 0.075;
+                            $perSec := $is1080 ? 0.075 : 0.05;
+                            {"type":"usd","usd": $base + $perSec * ($d - 1)}
+                        )
+                    : $contains($m, "turbo")
+                        ? (
+                            $is1080
+                                ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)}
+                                : (
+                                    $d <= 1 ? {"type":"usd","usd": 0.04}
+                                    : $d <= 2 ? {"type":"usd","usd": 0.05}
+                                    : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)}
+                                )
+                        )
+                    : {"type":"usd","usd": 0.04}
+            )
+            """,
+            ),
+        )
+
+    @classmethod
+    async def execute(
+        cls,
+        model: str,
+        image: Input.Image,
+        prompt: str,
+        duration: int,
+        seed: int,
+        resolution: str,
+        movement_amplitude: str,
+    ) -> IO.NodeOutput:
+        if get_number_of_images(image) > 1:
+            raise ValueError("Only one input image is allowed.")
+        validate_image_aspect_ratio(image, (1, 4), (4, 1))
+        validate_string(prompt, max_length=2000)
+        results = await execute_task(
+            cls,
+            VIDU_IMAGE_TO_VIDEO,
+            TaskCreationRequest(
+                model=model,
+                prompt=prompt,
+                duration=duration,
+                seed=seed,
+                resolution=resolution,
+                movement_amplitude=movement_amplitude,
+                images=await upload_images_to_comfyapi(
+                    cls,
+                    image,
+                    max_images=1,
+                    mime_type="image/png",
+                ),
+            ),
+        )
+        return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+
+
+class Vidu2ReferenceVideoNode(IO.ComfyNode):
+
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Vidu2ReferenceVideoNode",
+            display_name="Vidu2 Reference-to-Video Generation",
+            category="api node/video/Vidu",
+            description="Generate a video from multiple reference images and a prompt.",
+            inputs=[
+                IO.Combo.Input("model", options=["viduq2"]),
+                IO.Autogrow.Input(
+                    "subjects",
+                    template=IO.Autogrow.TemplateNames(
+                        IO.Image.Input("reference_images"),
+                        names=["subject1", "subject2", "subject3", "subject4", "subject5", "subject6", "subject7"],
+                        min=1,
+                    ),
+                    tooltip="For each subject, provide up to 3 reference images (7 images total across all subjects). "
+                    "Reference them in prompts via @subject{subject_id}.",
+                ),
+                IO.String.Input(
+                    "prompt",
+                    multiline=True,
+                    tooltip="A text prompt describing the video (max 2000 characters). "
+                    "Reference subjects via @subject{subject_id}.",
+                ),
+                IO.Boolean.Input(
+                    "audio",
+                    default=False,
+                    tooltip="When enabled, the video will contain generated speech and background music based on the prompt.",
+                    advanced=True,
+                ),
+                IO.Int.Input(
+                    "duration",
+                    default=5,
+                    min=1,
+                    max=10,
+                    step=1,
+                    display_mode=IO.NumberDisplay.slider,
+                ),
+                IO.Int.Input(
+                    "seed",
+                    default=1,
+                    min=0,
+                    max=2147483647,
+                    step=1,
+                    display_mode=IO.NumberDisplay.number,
+                    control_after_generate=True,
+                ),
+                IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "4:3", "3:4", "1:1"]),
+                IO.Combo.Input("resolution", options=["720p", "1080p"], advanced=True),
+                IO.Combo.Input(
+                    "movement_amplitude",
+                    options=["auto", "small", "medium", "large"],
+                    tooltip="The movement amplitude of objects in the frame.",
+                    advanced=True,
+                ),
+            ],
+            outputs=[
+                IO.Video.Output(),
+            ],
+            hidden=[
+                IO.Hidden.auth_token_comfy_org,
+                IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
+            ],
+            is_api_node=True,
+            price_badge=IO.PriceBadge(
+                depends_on=IO.PriceBadgeDepends(widgets=["audio", "duration", "resolution"]),
+                expr="""
+                (
+                    $is1080 := widgets.resolution = "1080p";
+                    $base := $is1080 ? 0.375 : 0.125;
+                    $perSec := $is1080 ? 0.05 : 0.025;
+                    $audioCost := widgets.audio = true ?
0.075 : 0; + {"type":"usd","usd": $base + $perSec * (widgets.duration - 1) + $audioCost} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + subjects: IO.Autogrow.Type, + prompt: str, + audio: bool, + duration: int, + seed: int, + aspect_ratio: str, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + total_images = 0 + for i in subjects: + if get_number_of_images(subjects[i]) > 3: + raise ValueError("Maximum number of images per subject is 3.") + for im in subjects[i]: + total_images += 1 + validate_image_aspect_ratio(im, (1, 4), (4, 1)) + validate_image_dimensions(im, min_width=128, min_height=128) + if total_images > 7: + raise ValueError("Too many reference images; the maximum allowed is 7.") + subjects_param: list[SubjectReference] = [] + for i in subjects: + subjects_param.append( + SubjectReference( + id=i, + images=await upload_images_to_comfyapi( + cls, + subjects[i], + max_images=3, + mime_type="image/png", + wait_label=f"Uploading reference images for {i}", + ), + ), + ) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + audio=audio, + duration=duration, + seed=seed, + aspect_ratio=aspect_ratio, + resolution=resolution, + movement_amplitude=movement_amplitude, + subjects=subjects_param, + ) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu2StartEndToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu2StartEndToVideoNode", + display_name="Vidu2 Start/End Frame-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from a start frame, an end frame, and a prompt.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), + IO.Image.Input("first_frame"), + IO.Image.Input("end_frame"), + IO.String.Input( + "prompt", + multiline=True, + tooltip="Prompt description (max 2000 characters).", + ), + IO.Int.Input( + "duration", + default=5, + min=2, + max=8, + step=1, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"], advanced=True), + IO.Combo.Input( + "movement_amplitude", + options=["auto", "small", "medium", "large"], + tooltip="The movement amplitude of objects in the frame.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"]), + expr=""" + ( + $m := widgets.model; + $d := widgets.duration; + $is1080 := widgets.resolution = "1080p"; + $contains($m, "pro-fast") + ? ( + $base := $is1080 ? 0.08 : 0.04; + $perSec := $is1080 ? 0.02 : 0.01; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "pro") + ? ( + $base := $is1080 ? 0.275 : 0.075; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : $contains($m, "turbo") + ? ( + $is1080 + ? {"type":"usd","usd": 0.175 + 0.05 * ($d - 1)} + : ( + $d <= 2 ? 
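+                                    /* turbo 720p: flat rate for clips up to 2 s */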
{"type":"usd","usd": 0.05} + : {"type":"usd","usd": 0.05 + 0.05 * ($d - 2)} + ) + ) + : {"type":"usd","usd": 0.04} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + duration: int, + seed: int, + resolution: str, + movement_amplitude: str, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + if get_number_of_images(first_frame) > 1: + raise ValueError("Only one input image is allowed for `first_frame`.") + if get_number_of_images(end_frame) > 1: + raise ValueError("Only one input image is allowed for `end_frame`.") + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model, + prompt=prompt, + duration=duration, + seed=seed, + resolution=resolution, + movement_amplitude=movement_amplitude, + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class ViduExtendVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduExtendVideoNode", + display_name="Vidu Video Extension", + category="api node/video/Vidu", + description="Extend an existing video by generating additional frames.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq2-pro", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq2-turbo", + [ + IO.Int.Input( + "duration", + default=4, + min=1, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the extended video in seconds.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + ], + ), + ], + tooltip="Model to use for video extension.", + ), + IO.Video.Input( + "video", + tooltip="The source video to extend.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for the extended video (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Image.Input("end_frame", optional=True), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $m := widgets.model; + $d := $lookup(widgets, "model.duration"); + $res := $lookup(widgets, "model.resolution"); + $contains($m, "pro") + ? 
( + $base := $lookup({"720p": 0.15, "1080p": 0.3}, $res); + $perSec := $lookup({"720p": 0.05, "1080p": 0.075}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + : ( + $base := $lookup({"720p": 0.075, "1080p": 0.2}, $res); + $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); + {"type":"usd","usd": $base + $perSec * ($d - 1)} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + video: Input.Video, + prompt: str, + seed: int, + end_frame: Input.Image | None = None, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + validate_video_duration(video, min_duration=4, max_duration=55) + image_url = None + if end_frame is not None: + validate_image_aspect_ratio(end_frame, (1, 4), (4, 1)) + validate_image_dimensions(end_frame, min_width=128, min_height=128) + image_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + results = await execute_task( + cls, + "/proxy/vidu/extend", + TaskExtendCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading video"), + images=[image_url] if image_url else None, + ), + max_poll_attempts=480, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +def _generate_frame_inputs(count: int) -> list: + """Generate input widgets for a given number of frames.""" + inputs = [] + for i in range(1, count + 1): + inputs.extend( + [ + IO.String.Input( + f"prompt{i}", + multiline=True, + default="", + tooltip=f"Text prompt for frame {i} transition.", + ), + IO.Image.Input( + f"end_image{i}", + tooltip=f"End frame image for segment {i}. Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + f"duration{i}", + default=4, + min=2, + max=7, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip=f"Duration for segment {i} in seconds.", + ), + ] + ) + return inputs + + +class ViduMultiFrameVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ViduMultiFrameVideoNode", + display_name="Vidu Multi-Frame Video Generation", + category="api node/video/Vidu", + description="Generate a video with multiple keyframe transitions.", + inputs=[ + IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), + IO.Image.Input( + "start_image", + tooltip="The starting frame image. 
Aspect ratio must be between 1:4 and 4:1.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.DynamicCombo.Input( + "frames", + options=[ + IO.DynamicCombo.Option("2", _generate_frame_inputs(2)), + IO.DynamicCombo.Option("3", _generate_frame_inputs(3)), + IO.DynamicCombo.Option("4", _generate_frame_inputs(4)), + IO.DynamicCombo.Option("5", _generate_frame_inputs(5)), + IO.DynamicCombo.Option("6", _generate_frame_inputs(6)), + IO.DynamicCombo.Option("7", _generate_frame_inputs(7)), + IO.DynamicCombo.Option("8", _generate_frame_inputs(8)), + IO.DynamicCombo.Option("9", _generate_frame_inputs(9)), + ], + tooltip="Number of keyframe transitions (2-9).", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=[ + "model", + "resolution", + "frames", + "frames.duration1", + "frames.duration2", + "frames.duration3", + "frames.duration4", + "frames.duration5", + "frames.duration6", + "frames.duration7", + "frames.duration8", + "frames.duration9", + ] + ), + expr=""" + ( + $m := widgets.model; + $n := $number(widgets.frames); + $is1080 := widgets.resolution = "1080p"; + $d1 := $lookup(widgets, "frames.duration1"); + $d2 := $lookup(widgets, "frames.duration2"); + $d3 := $n >= 3 ? $lookup(widgets, "frames.duration3") : 0; + $d4 := $n >= 4 ? $lookup(widgets, "frames.duration4") : 0; + $d5 := $n >= 5 ? $lookup(widgets, "frames.duration5") : 0; + $d6 := $n >= 6 ? $lookup(widgets, "frames.duration6") : 0; + $d7 := $n >= 7 ? $lookup(widgets, "frames.duration7") : 0; + $d8 := $n >= 8 ? $lookup(widgets, "frames.duration8") : 0; + $d9 := $n >= 9 ? $lookup(widgets, "frames.duration9") : 0; + $totalDuration := $d1 + $d2 + $d3 + $d4 + $d5 + $d6 + $d7 + $d8 + $d9; + $contains($m, "pro") + ? ( + $base := $is1080 ? 0.3 : 0.15; + $perSec := $is1080 ? 0.075 : 0.05; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + : ( + $base := $is1080 ? 0.2 : 0.075; + $perSec := $is1080 ? 
0.05 : 0.025; + {"type":"usd","usd": $n * $base + $perSec * $totalDuration} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + start_image: Input.Image, + seed: int, + resolution: str, + frames: dict, + ) -> IO.NodeOutput: + validate_image_aspect_ratio(start_image, (1, 4), (4, 1)) + frame_count = int(frames["frames"]) + image_settings: list[FrameSetting] = [] + for i in range(1, frame_count + 1): + validate_image_aspect_ratio(frames[f"end_image{i}"], (1, 4), (4, 1)) + validate_string(frames[f"prompt{i}"], max_length=2000) + start_image_url = await upload_image_to_comfyapi( + cls, + start_image, + mime_type="image/png", + wait_label="Uploading start image", + ) + for i in range(1, frame_count + 1): + image_settings.append( + FrameSetting( + prompt=frames[f"prompt{i}"], + key_image=await upload_image_to_comfyapi( + cls, + frames[f"end_image{i}"], + mime_type="image/png", + wait_label=f"Uploading end image({i})", + ), + duration=frames[f"duration{i}"], + ) + ) + results = await execute_task( + cls, + "/proxy/vidu/multiframe", + TaskMultiFrameCreationRequest( + model=model, + seed=seed, + resolution=resolution, + start_image=start_image_url, + image_settings=image_settings, + ), + max_poll_attempts=480 * frame_count, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3TextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3TextToVideoNode", + display_name="Vidu Q3 Text-to-Video Generation", + category="api node/video/Vidu", + description="Generate video from a text prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "3:4", "4:3", "1:1"], + tooltip="The aspect ratio of the output video.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq3-turbo", + [ + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "3:4", "4:3", "1:1"], + tooltip="The aspect ratio of the output video.", + ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A textual description for video generation, with a maximum length of 2000 characters.", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + 
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $d := $lookup(widgets, "model.duration"); + $contains(widgets.model, "turbo") + ? ( + $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); + {"type":"usd","usd": $rate * $d} + ) + : ( + $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); + {"type":"usd","usd": $rate * $d} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2000) + results = await execute_task( + cls, + VIDU_TEXT_TO_VIDEO, + TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + audio=model["audio"], + ), + max_poll_attempts=640, + ) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) + + +class Vidu3ImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Vidu3ImageToVideoNode", + display_name="Vidu Q3 Image-to-Video Generation", + category="api node/video/Vidu", + description="Generate a video from an image and an optional prompt.", + inputs=[ + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "viduq3-pro", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "2K"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + IO.DynamicCombo.Option( + "viduq3-turbo", + [ + IO.Combo.Input( + "resolution", + options=["720p", "1080p"], + tooltip="Resolution of the output video.", + ), + IO.Int.Input( + "duration", + default=5, + min=1, + max=16, + step=1, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds.", + ), + IO.Boolean.Input( + "audio", + default=False, + tooltip="When enabled, outputs video with sound " + "(including dialogue and sound effects).", + ), + ], + ), + ], + tooltip="Model to use for video generation.", + ), + IO.Image.Input( + "image", + tooltip="An image to be used as the start frame of the generated video.", + ), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="An optional text prompt for video generation (max 2000 characters).", + ), + IO.Int.Input( + "seed", + default=1, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]), + expr=""" + ( + $res := $lookup(widgets, "model.resolution"); + $d := $lookup(widgets, "model.duration"); + $contains(widgets.model, "turbo") + ? 
(
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2K": 0.2}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: dict,
+ image: Input.Image,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_image_aspect_ratio(image, (1, 4), (4, 1))
+ validate_string(prompt, max_length=2000)
+ results = await execute_task(
+ cls,
+ VIDU_IMAGE_TO_VIDEO,
+ TaskCreationRequest(
+ model=model["model"],
+ prompt=prompt,
+ duration=model["duration"],
+ seed=seed,
+ resolution=model["resolution"],
+ audio=model["audio"],
+ images=[await upload_image_to_comfyapi(cls, image)],
+ ),
+ max_poll_attempts=720,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+
+
+class Vidu3StartEndToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Vidu3StartEndToVideoNode",
+ display_name="Vidu Q3 Start/End Frame-to-Video Generation",
+ category="api node/video/Vidu",
+ description="Generate a video from a start frame, an end frame, and a prompt.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "viduq3-pro",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model to use for video generation.",
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt description (max 2000 characters).",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ step=1,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
+ expr="""
+ (
+ $res := $lookup(widgets, "model.resolution");
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ?
( + $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res); + {"type":"usd","usd": $rate * $d} + ) + : ( + $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res); + {"type":"usd","usd": $rate * $d} + ) + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: dict, + first_frame: Input.Image, + end_frame: Input.Image, + prompt: str, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, max_length=2000) + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + payload = TaskCreationRequest( + model=model["model"], + prompt=prompt, + duration=model["duration"], + seed=seed, + resolution=model["resolution"], + audio=model["audio"], + images=[ + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] + for frame in (first_frame, end_frame) + ], + ) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload) + return IO.NodeOutput(await download_url_to_video_output(results[0].url)) class ViduExtension(ComfyExtension): @@ -558,6 +1712,15 @@ class ViduExtension(ComfyExtension): ViduImageToVideoNode, ViduReferenceVideoNode, ViduStartEndToVideoNode, + Vidu2TextToVideoNode, + Vidu2ImageToVideoNode, + Vidu2ReferenceVideoNode, + Vidu2StartEndToVideoNode, + ViduExtendVideoNode, + ViduMultiFrameVideoNode, + Vidu3TextToVideoNode, + Vidu3ImageToVideoNode, + Vidu3StartEndToVideoNode, ] diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 17b680e13..e2afe7f9c 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -13,7 +13,9 @@ from comfy_api_nodes.util import ( poll_op, sync_op, tensor_to_base64_string, + upload_video_to_comfyapi, validate_audio_duration, + validate_video_duration, ) @@ -41,19 +43,25 @@ class Image2VideoInputField(BaseModel): audio_url: str | None = Field(None) +class Reference2VideoInputField(BaseModel): + prompt: str = Field(...) + negative_prompt: str | None = Field(None) + reference_video_urls: list[str] = Field(...) + + class Txt2ImageParametersField(BaseModel): size: str = Field(...) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) class Image2ImageParametersField(BaseModel): size: str | None = Field(None) n: int = Field(1, description="Number of images to generate.") # we support only value=1 seed: int = Field(..., ge=0, le=2147483647) - watermark: bool = Field(True) + watermark: bool = Field(False) class Text2VideoParametersField(BaseModel): @@ -61,7 +69,7 @@ class Text2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") @@ -71,11 +79,19 @@ class Image2VideoParametersField(BaseModel): seed: int = Field(..., ge=0, le=2147483647) duration: int = Field(5, ge=5, le=15) prompt_extend: bool = Field(True) - watermark: bool = Field(True) + watermark: bool = Field(False) audio: bool = Field(False, description="Whether to generate audio automatically.") shot_type: str = Field("single") +class Reference2VideoParametersField(BaseModel): + size: str = Field(...) 
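+ # "size" is a concrete "width*height" string (e.g. "1280*720"), built in execute() from the UI size combo.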
+ duration: int = Field(5, ge=5, le=15) + shot_type: str = Field("single") + seed: int = Field(..., ge=0, le=2147483647) + watermark: bool = Field(False) + + class Text2ImageTaskCreationRequest(BaseModel): model: str = Field(...) input: Text2ImageInputField = Field(...) @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel): parameters: Image2VideoParametersField = Field(...) +class Reference2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + input: Reference2VideoInputField = Field(...) + parameters: Reference2VideoParametersField = Field(...) + + class TaskCreationOutputField(BaseModel): task_id: str = Field(...) task_status: str = Field(...) @@ -205,12 +227,14 @@ class WanTextToImageApi(IO.ComfyNode): default=True, tooltip="Whether to enhance the prompt with AI assistance.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, + advanced=True, ), ], outputs=[ @@ -222,6 +246,9 @@ class WanTextToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -234,7 +261,7 @@ class WanTextToImageApi(IO.ComfyNode): height: int = 1024, seed: int = 0, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, ): initial_response = await sync_op( cls, @@ -327,9 +354,10 @@ class WanImageToImageApi(IO.ComfyNode): ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, + advanced=True, ), ], outputs=[ @@ -341,6 +369,9 @@ class WanImageToImageApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.03}""", + ), ) @classmethod @@ -353,7 +384,7 @@ class WanImageToImageApi(IO.ComfyNode): # width: int = 1024, # height: int = 1024, seed: int = 0, - watermark: bool = True, + watermark: bool = False, ): n_images = get_number_of_images(image) if n_images not in (1, 2): @@ -467,18 +498,21 @@ class WanTextToVideoApi(IO.ComfyNode): default=False, optional=True, tooltip="If no audio input is provided, generate audio automatically.", + advanced=True, ), IO.Boolean.Input( "prompt_extend", default=True, tooltip="Whether to enhance the prompt with AI assistance.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, + advanced=True, ), IO.Combo.Input( "shot_type", @@ -487,6 +521,7 @@ class WanTextToVideoApi(IO.ComfyNode): "single continuous shot or multiple shots with cuts. 
" "This parameter takes effect only when prompt_extend is True.", optional=True, + advanced=True, ), ], outputs=[ @@ -498,6 +533,17 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "size"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $resKey := $substringBefore(widgets.size, ":"); + $pps := $lookup($ppsTable, $resKey); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -512,7 +558,7 @@ class WanTextToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if "480p" in size and model == "wan2.6-t2v": @@ -628,18 +674,21 @@ class WanImageToVideoApi(IO.ComfyNode): default=False, optional=True, tooltip="If no audio input is provided, generate audio automatically.", + advanced=True, ), IO.Boolean.Input( "prompt_extend", default=True, tooltip="Whether to enhance the prompt with AI assistance.", optional=True, + advanced=True, ), IO.Boolean.Input( "watermark", - default=True, + default=False, tooltip="Whether to add an AI-generated watermark to the result.", optional=True, + advanced=True, ), IO.Combo.Input( "shot_type", @@ -648,6 +697,7 @@ class WanImageToVideoApi(IO.ComfyNode): "single continuous shot or multiple shots with cuts. " "This parameter takes effect only when prompt_extend is True.", optional=True, + advanced=True, ), ], outputs=[ @@ -659,6 +709,16 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]), + expr=""" + ( + $ppsTable := { "480p": 0.05, "720p": 0.1, "1080p": 0.15 }; + $pps := $lookup($ppsTable, widgets.resolution); + { "type": "usd", "usd": $round($pps * widgets.duration, 2) } + ) + """, + ), ) @classmethod @@ -674,7 +734,7 @@ class WanImageToVideoApi(IO.ComfyNode): seed: int = 0, generate_audio: bool = False, prompt_extend: bool = True, - watermark: bool = True, + watermark: bool = False, shot_type: str = "single", ): if get_number_of_images(image) != 1: @@ -721,6 +781,161 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) +class WanReferenceVideoApi(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="WanReferenceVideoApi", + display_name="Wan Reference to Video", + category="api node/video/Wan", + description="Use the character and voice from input videos, combined with a prompt, " + "to generate a new video that maintains character consistency.", + inputs=[ + IO.Combo.Input("model", options=["wan2.6-r2v"]), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. Supports English and Chinese. 
" + "Use identifiers such as `character1` and `character2` to refer to the reference characters.", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative prompt describing what to avoid.", + ), + IO.Autogrow.Input( + "reference_videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("reference_video"), + names=["character1", "character2", "character3"], + min=1, + ), + ), + IO.Combo.Input( + "size", + options=[ + "720p: 1:1 (960x960)", + "720p: 16:9 (1280x720)", + "720p: 9:16 (720x1280)", + "720p: 4:3 (1088x832)", + "720p: 3:4 (832x1088)", + "1080p: 1:1 (1440x1440)", + "1080p: 16:9 (1920x1080)", + "1080p: 9:16 (1080x1920)", + "1080p: 4:3 (1632x1248)", + "1080p: 3:4 (1248x1632)", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=5, + max=10, + step=5, + display_mode=IO.NumberDisplay.slider, + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + ), + IO.Combo.Input( + "shot_type", + options=["single", "multi"], + tooltip="Specifies the shot type for the generated video, that is, whether the video is a " + "single continuous shot or multiple shots with cuts.", + advanced=True, + ), + IO.Boolean.Input( + "watermark", + default=False, + tooltip="Whether to add an AI-generated watermark to the result.", + advanced=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["size", "duration"]), + expr=""" + ( + $rate := $contains(widgets.size, "1080p") ? 0.15 : 0.10; + $inputMin := 2 * $rate; + $inputMax := 5 * $rate; + $outputPrice := widgets.duration * $rate; + { + "type": "range_usd", + "min_usd": $inputMin + $outputPrice, + "max_usd": $inputMax + $outputPrice + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + negative_prompt: str, + reference_videos: IO.Autogrow.Type, + size: str, + duration: int, + seed: int, + shot_type: str, + watermark: bool, + ): + reference_video_urls = [] + for i in reference_videos: + validate_video_duration(reference_videos[i], min_duration=2, max_duration=30) + for i in reference_videos: + reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i])) + width, height = RES_IN_PARENS.search(size).groups() + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Reference2VideoTaskCreationRequest( + model=model, + input=Reference2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls + ), + parameters=Reference2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + shot_type=shot_type, + watermark=watermark, + seed=seed, + ), + ), + ) + if not initial_response.output: + raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), + response_model=VideoTaskStatusResponse, + status_extractor=lambda x: x.output.task_status, + poll_interval=6, + max_poll_attempts=280, + ) + return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) + + class WanApiExtension(ComfyExtension): 
@override
 async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -729,6 +944,7 @@ class WanApiExtension(ComfyExtension):
 WanImageToImageApi,
 WanTextToVideoApi,
 WanImageToVideoApi,
+ WanReferenceVideoApi,
 ]


diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py
new file mode 100644
index 000000000..c59fafd3b
--- /dev/null
+++ b/comfy_api_nodes/nodes_wavespeed.py
@@ -0,0 +1,178 @@
+from typing_extensions import override
+
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.apis.wavespeed import (
+ FlashVSRRequest,
+ TaskCreatedResponse,
+ TaskResultResponse,
+ SeedVR2ImageRequest,
+)
+from comfy_api_nodes.util import (
+ ApiEndpoint,
+ download_url_to_video_output,
+ poll_op,
+ sync_op,
+ upload_video_to_comfyapi,
+ validate_container_format_is_mp4,
+ validate_video_duration,
+ upload_images_to_comfyapi,
+ get_number_of_images,
+ download_url_to_image_tensor,
+)
+
+
+class WavespeedFlashVSRNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="WavespeedFlashVSRNode",
+ display_name="FlashVSR Video Upscale",
+ category="api node/video/WaveSpeed",
+ description="Fast, high-quality video upscaler that "
+ "boosts resolution and restores clarity for low-resolution or blurry footage.",
+ inputs=[
+ IO.Video.Input("video"),
+ IO.Combo.Input("target_resolution", options=["720p", "1080p", "2K", "4K"]),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["target_resolution"]),
+ expr="""
+ (
+ $price_for_1sec := {"720p": 0.012, "1080p": 0.018, "2K": 0.024, "4K": 0.032};
+ {
+ "type":"usd",
+ "usd": $lookup($price_for_1sec, widgets.target_resolution),
+ "format":{"suffix": "/second", "approximate": true}
+ }
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ video: Input.Video,
+ target_resolution: str,
+ ) -> IO.NodeOutput:
+ validate_container_format_is_mp4(video)
+ validate_video_duration(video, min_duration=5, max_duration=60 * 10)
+ initial_res = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/wavespeed/api/v3/wavespeed-ai/flashvsr", method="POST"),
+ response_model=TaskCreatedResponse,
+ data=FlashVSRRequest(
+ target_resolution=target_resolution.lower(),
+ video=await upload_video_to_comfyapi(cls, video),
+ duration=video.get_duration(),
+ ),
+ )
+ if initial_res.code != 200:
+ raise ValueError(f"Task creation failed with code={initial_res.code} and message={initial_res.message}")
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"),
+ response_model=TaskResultResponse,
+ status_extractor=lambda x: "failed" if x.data is None else x.data.status,
+ poll_interval=10.0,
+ max_poll_attempts=480,
+ )
+ if final_response.code != 200:
+ raise ValueError(
+ f"Task processing failed with code={final_response.code} and message={final_response.message}"
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.outputs[0]))
+
+
+class WavespeedImageUpscaleNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="WavespeedImageUpscaleNode",
+ display_name="WaveSpeed Image Upscale",
+ category="api node/image/WaveSpeed",
+ description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
+ inputs=[
+ IO.Combo.Input("model",
options=["SeedVR2", "Ultimate"]), + IO.Image.Input("image"), + IO.Combo.Input("target_resolution", options=["2K", "4K", "8K"]), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $prices := {"seedvr2": 0.01, "ultimate": 0.06}; + {"type":"usd", "usd": $lookup($prices, widgets.model)} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + model: str, + image: Input.Image, + target_resolution: str, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Exactly one input image is required.") + if model == "SeedVR2": + model_path = "seedvr2/image" + else: + model_path = "ultimate-image-upscaler" + initial_res = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/wavespeed-ai/{model_path}", method="POST"), + response_model=TaskCreatedResponse, + data=SeedVR2ImageRequest( + target_resolution=target_resolution.lower(), + image=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + ), + ) + if initial_res.code != 200: + raise ValueError(f"Task creation fails with code={initial_res.code} and message={initial_res.message}") + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wavespeed/api/v3/predictions/{initial_res.data.id}/result"), + response_model=TaskResultResponse, + status_extractor=lambda x: "failed" if x.data is None else x.data.status, + poll_interval=10.0, + max_poll_attempts=480, + ) + if final_response.code != 200: + raise ValueError( + f"Task processing failed with code={final_response.code} and message={final_response.message}" + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.outputs[0])) + + +class WavespeedExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + WavespeedFlashVSRNode, + WavespeedImageUpscaleNode, + ] + + +async def comfy_entrypoint() -> WavespeedExtension: + return WavespeedExtension() diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml deleted file mode 100644 index d9e3cab70..000000000 --- a/comfy_api_nodes/redocly-dev.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. -# This is used for development purposes to generate stubs for unreleased API endpoints. -apis: - filter: - root: openapi.yaml - decorators: - filter-in: - property: tags - value: ['API Nodes'] - matchStrategy: all diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml deleted file mode 100644 index d102345b1..000000000 --- a/comfy_api_nodes/redocly.yaml +++ /dev/null @@ -1,10 +0,0 @@ -# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes. 
-
-apis:
- filter:
- root: openapi.yaml
- decorators:
- filter-in:
- property: tags
- value: ['API Nodes', 'Released']
- matchStrategy: all
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 4cc22abfb..0cb9a47c7 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -9,9 +9,13 @@ from .client import (
 from .conversions import (
 audio_bytes_to_audio_input,
 audio_input_to_mp3,
+ audio_ndarray_to_bytesio,
+ audio_tensor_to_contiguous_ndarray,
 audio_to_base64_string,
 bytesio_to_image_tensor,
+ convert_mask_to_image,
 downscale_image_tensor,
+ downscale_image_tensor_by_max_side,
 image_tensor_pair_to_batch,
 pil_to_bytesio,
 resize_mask_to_image,
@@ -26,12 +30,15 @@ from .conversions import (
 from .download_helpers import (
 download_url_as_bytesio,
 download_url_to_bytesio,
+ download_url_to_file_3d,
 download_url_to_image_tensor,
 download_url_to_video_output,
 )
 from .upload_helpers import (
+ upload_3d_model_to_comfyapi,
 upload_audio_to_comfyapi,
 upload_file_to_comfyapi,
+ upload_image_to_comfyapi,
 upload_images_to_comfyapi,
 upload_video_to_comfyapi,
 )
@@ -58,21 +65,28 @@ __all__ = [
 "sync_op",
 "sync_op_raw",
 # Upload helpers
+ "upload_3d_model_to_comfyapi",
 "upload_audio_to_comfyapi",
 "upload_file_to_comfyapi",
+ "upload_image_to_comfyapi",
 "upload_images_to_comfyapi",
 "upload_video_to_comfyapi",
 # Download helpers
 "download_url_as_bytesio",
 "download_url_to_bytesio",
+ "download_url_to_file_3d",
 "download_url_to_image_tensor",
 "download_url_to_video_output",
 # Conversions
 "audio_bytes_to_audio_input",
 "audio_input_to_mp3",
+ "audio_ndarray_to_bytesio",
+ "audio_tensor_to_contiguous_ndarray",
 "audio_to_base64_string",
 "bytesio_to_image_tensor",
+ "convert_mask_to_image",
 "downscale_image_tensor",
+ "downscale_image_tensor_by_max_side",
 "image_tensor_pair_to_batch",
 "pil_to_bytesio",
 "resize_mask_to_image",
diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py
index 491e6b6a8..648defe3d 100644
--- a/comfy_api_nodes/util/_helpers.py
+++ b/comfy_api_nodes/util/_helpers.py
@@ -1,16 +1,22 @@
 import asyncio
 import contextlib
 import os
+import re
 import time
 from collections.abc import Callable
 from io import BytesIO

+from yarl import URL
+
 from comfy.cli_args import args
 from comfy.model_management import processing_interrupted
 from comfy_api.latest import IO

 from .common_exceptions import ProcessingInterrupted

+_HAS_PCT_ESC = re.compile(r"%[0-9A-Fa-f]{2}") # any % followed by 2 hex digits
+_HAS_BAD_PCT = re.compile(r"%(?![0-9A-Fa-f]{2})") # any % not followed by 2 hex digits
+

 def is_processing_interrupted() -> bool:
 """Return True if user/runtime requested interruption."""
@@ -69,3 +75,17 @@ def get_fs_object_size(path_or_object: str | BytesIO) -> int:
 if isinstance(path_or_object, str):
 return os.path.getsize(path_or_object)
 return len(path_or_object.getvalue())
+
+
+def to_aiohttp_url(url: str) -> URL:
+ """If `url` appears to be already percent-encoded (contains at least one valid %HH
+ escape and no malformed '%' sequences) and contains no raw whitespace/control
+ characters, preserve the original encoding byte-for-byte (important for signed/presigned URLs).
+ Otherwise, return `URL(url)` and allow yarl to normalize/quote as needed.""" + if any(c.isspace() for c in url) or any(ord(c) < 0x20 for c in url): + # Avoid encoded=True if URL contains raw whitespace/control chars + return URL(url) + if _HAS_PCT_ESC.search(url) and not _HAS_BAD_PCT.search(url): + # Preserve encoding only if it appears pre-encoded AND has no invalid % sequences + return URL(url, encoded=True) + return URL(url) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index bf37cba5f..9d730b81a 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -57,6 +57,7 @@ class _RequestConfig: files: dict[str, Any] | list[tuple[str, Any]] | None multipart_parser: Callable | None max_retries: int + max_retries_on_rate_limit: int retry_delay: float retry_backoff: float wait_label: str = "Waiting" @@ -65,6 +66,8 @@ class _RequestConfig: final_label_on_success: str | None = "Completed" progress_origin_ts: float | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None + is_rate_limited: Callable[[int, Any], bool] | None = None + response_header_validator: Callable[[dict[str, str]], None] | None = None @dataclass @@ -78,10 +81,10 @@ class _PollUIState: active_since: float | None = None # start time of current active interval (None if queued) -_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] async def sync_op( @@ -103,6 +106,8 @@ async def sync_op( final_label_on_success: str | None = "Completed", progress_origin_ts: float | None = None, monitor_progress: bool = True, + max_retries_on_rate_limit: int = 16, + is_rate_limited: Callable[[int, Any], bool] | None = None, ) -> M: raw = await sync_op_raw( cls, @@ -122,6 +127,8 @@ async def sync_op( final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, monitor_progress=monitor_progress, + max_retries_on_rate_limit=max_retries_on_rate_limit, + is_rate_limited=is_rate_limited, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -141,11 +148,11 @@ async def poll_op( queued_statuses: list[str | int] | None = None, data: BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, - max_retries_per_poll: int = 3, + max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, - retry_backoff_per_poll: float = 2.0, + retry_backoff_per_poll: float = 1.4, estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, @@ -194,11 +201,15 @@ async def sync_op_raw( final_label_on_success: str | None = "Completed", progress_origin_ts: float | None = None, monitor_progress: bool = True, + max_retries_on_rate_limit: int = 16, + is_rate_limited: Callable[[int, Any], bool] | None = None, + response_header_validator: Callable[[dict[str, str]], None] | None = None, ) -> dict[str, Any] | bytes: """ Make a single network request. 
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). - If as_binary=True: returns bytes. + - response_header_validator: optional callback receiving response headers dict """ if isinstance(data, BaseModel): data = data.model_dump(exclude_none=True) @@ -222,6 +233,9 @@ async def sync_op_raw( final_label_on_success=final_label_on_success, progress_origin_ts=progress_origin_ts, price_extractor=price_extractor, + max_retries_on_rate_limit=max_retries_on_rate_limit, + is_rate_limited=is_rate_limited, + response_header_validator=response_header_validator, ) return await _request_base(cfg, expect_binary=as_binary) @@ -238,11 +252,11 @@ async def poll_op_raw( queued_statuses: list[str | int] | None = None, data: dict[str, Any] | BaseModel | None = None, poll_interval: float = 5.0, - max_poll_attempts: int = 120, + max_poll_attempts: int = 160, timeout_per_poll: float = 120.0, - max_retries_per_poll: int = 3, + max_retries_per_poll: int = 10, retry_delay_per_poll: float = 1.0, - retry_backoff_per_poll: float = 2.0, + retry_backoff_per_poll: float = 1.4, estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, @@ -430,9 +444,9 @@ def _display_text( if status: display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") if price is not None: - p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + p = f"{float(price) * 211:,.1f}".rstrip("0").rstrip(".") if p != "0": - display_lines.append(f"Price: ${p}") + display_lines.append(f"Price: {p} credits") if text is not None: display_lines.append(text) if display_lines: @@ -506,7 +520,7 @@ def _friendly_http_message(status: int, body: Any) -> str: if status == 409: return "There is a problem with your account. Please contact support@comfy.org." if status == 429: - return "Rate Limit Exceeded: Please try again later." + return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again." 
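+ # Otherwise fall back to whatever error message the server included in the response body.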
try: if isinstance(body, dict): err = body.get("error") @@ -586,6 +600,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() attempt = 0 delay = cfg.retry_delay + rate_limit_attempts = 0 + rate_limit_delay = cfg.retry_delay operation_succeeded: bool = False final_elapsed_seconds: int | None = None extracted_price: float | None = None @@ -653,17 +669,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): payload_headers["Content-Type"] = "application/json" payload_kw["json"] = cfg.data or {} - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - ) - except Exception as _log_e: - logging.debug("[DEBUG] request logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) req_coro = sess.request(method, url, params=params, **payload_kw) req_task = asyncio.create_task(req_coro) @@ -688,41 +701,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): body = await resp.json() except (ContentTypeError, json.JSONDecodeError): body = await resp.text() - if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + should_retry = False + wait_time = 0.0 + retry_label = "" + is_rl = resp.status == 429 or ( + cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body) + ) + if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit: + rate_limit_attempts += 1 + wait_time = min(rate_limit_delay, 30.0) + rate_limit_delay *= cfg.retry_backoff + retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}" + should_retry = True + elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries: + wait_time = delay + delay *= cfg.retry_backoff + retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}" + should_retry = True + + if should_retry: logging.warning( - "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + "HTTP %s %s -> %s. 
Waiting %.2fs (%s).", method, url, resp.status, - delay, - attempt, - cfg.max_retries, + wait_time, + retry_label, ) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=body, - error_message=_friendly_http_message(resp.status, body), - ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) - - await sleep_with_interrupt( - delay, - cfg.node_cls, - cfg.wait_label if cfg.monitor_progress else None, - start_time if cfg.monitor_progress else None, - cfg.estimated_total, - display_callback=_display_time_progress if cfg.monitor_progress else None, - ) - delay *= cfg.retry_backoff - continue - msg = _friendly_http_message(resp.status, body) - try: request_logger.log_request_response( operation_id=operation_id, request_method=method, @@ -730,10 +735,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): response_status_code=resp.status, response_headers=dict(resp.headers), response_content=body, - error_message=msg, + error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)", ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + await sleep_with_interrupt( + wait_time, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + continue + msg = _friendly_http_message(resp.status, body) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) raise Exception(msg) if expect_binary: @@ -751,19 +773,22 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total ) bytes_payload = bytes(buff) + resp_headers = {k.lower(): v for k, v in resp.headers.items()} + if cfg.price_extractor: + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(resp_headers) + if cfg.response_header_validator: + cfg.response_header_validator(resp_headers) operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=bytes_payload, - ) - except Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=resp_headers, + response_content=bytes_payload, + ) return bytes_payload else: try: @@ -780,45 +805,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None operation_succeeded = True final_elapsed_seconds = int(time.monotonic() - start_time) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - except 
Exception as _log_e: - logging.debug("[DEBUG] response logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) return payload except ProcessingInterrupted: logging.debug("Polling was interrupted by user") raise except (ClientError, OSError) as e: - if attempt <= cfg.max_retries: + if (attempt - rate_limit_attempts) <= cfg.max_retries: logging.warning( "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", method, url, delay, - attempt, + attempt - rate_limit_attempts, cfg.max_retries, str(e), ) - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) - except Exception as _log_e: - logging.debug("[DEBUG] request error logging failed: %s", _log_e) + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) await sleep_with_interrupt( delay, cfg.node_cls, @@ -831,23 +850,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): continue diag = await _diagnose_connectivity() if not diag["internet_accessible"]: - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=dict(payload_headers) if payload_headers else None, - request_params=dict(params) if params else None, - request_data=request_body_log, - error_message=f"LocalNetworkError: {str(e)}", - ) - except Exception as _log_e: - logging.debug("[DEBUG] final error logging failed: %s", _log_e) - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - try: request_logger.log_request_response( operation_id=operation_id, request_method=method, @@ -855,10 +857,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool): request_headers=dict(payload_headers) if payload_headers else None, request_params=dict(params) if params else None, request_data=request_body_log, - error_message=f"ApiServerError: {str(e)}", + error_message=f"LocalNetworkError: {str(e)}", ) - except Exception as _log_e: - logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) raise ApiServerError( f"The API server at {default_base_url()} is currently unreachable. " f"The service may be experiencing issues." 
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index d64239c86..82b6d22a5 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -55,16 +55,15 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to def tensor_to_bytesio( image: torch.Tensor, - name: str | None = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", + *, + total_pixels: int | None = 2048 * 2048, + mime_type: str | None = "image/png", ) -> BytesIO: """Converts a torch.Tensor image to a named BytesIO object. Args: image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). Returns: @@ -75,17 +74,18 @@ def tensor_to_bytesio( pil_image = tensor_to_pil(image, total_pixels=total_pixels) img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + img_binary.name = f"{uuid.uuid4()}.{mimetype_to_extension(mime_type)}" return img_binary -def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: +def tensor_to_pil(image: torch.Tensor, total_pixels: int | None = 2048 * 2048) -> Image.Image: """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" if len(image.shape) > 3: image = image[0] # TODO: remove alpha if not allowed and present input_tensor = image.cpu() - input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + if total_pixels is not None: + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() image_np = (input_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) return img @@ -93,14 +93,14 @@ def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image def tensor_to_base64_string( image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, + total_pixels: int | None = 2048 * 2048, mime_type: str = "image/png", ) -> str: """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. Args: image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. + total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed. mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). 
Returns:
@@ -144,16 +144,31 @@ def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024)
 return s


+def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor:
+ """Downscale input image tensor so the largest dimension is at most max_side pixels."""
+ samples = image.movedim(-1, 1)
+ height, width = samples.shape[2], samples.shape[3]
+ max_dim = max(width, height)
+ if max_dim <= max_side:
+ return image
+ scale_by = max_side / max_dim
+ new_width = round(width * scale_by)
+ new_height = round(height * scale_by)
+ s = common_upscale(samples, new_width, new_height, "lanczos", "disabled")
+ s = s.movedim(1, -1)
+ return s
+
+
 def tensor_to_data_uri(
 image_tensor: torch.Tensor,
- total_pixels: int = 2048 * 2048,
+ total_pixels: int | None = 2048 * 2048,
 mime_type: str = "image/png",
 ) -> str:
 """Converts a tensor image to a Data URI string.

 Args:
 image_tensor: Input torch.Tensor image.
- total_pixels: Maximum total pixels for potential downscaling.
+ total_pixels: Maximum total pixels for downscaling. If None, no downscaling is performed.
 mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').

 Returns:
@@ -451,6 +466,12 @@ def resize_mask_to_image(
 return mask


+def convert_mask_to_image(mask: Input.Image) -> torch.Tensor:
+ """Make the mask have the expected number of dims (4) and channels (3) so it is recognized as an image."""
+ mask = mask.unsqueeze(-1)
+ return torch.cat([mask] * 3, dim=-1)
+
+
 def text_filepath_to_base64_string(filepath: str) -> str:
 """Converts a text file to a base64 string."""
 with open(filepath, "rb") as f:
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 3e0d0352d..aa588d038 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -11,7 +11,8 @@ import torch
 from aiohttp.client_exceptions import ClientError, ContentTypeError

 from comfy_api.latest import IO as COMFY_IO
-from comfy_api.latest import InputImpl
+from comfy_api.latest import InputImpl, Types
+from folder_paths import get_output_directory

 from .
import request_logger
from ._helpers import (
@@ -19,6 +20,7 @@ from ._helpers import (
 get_auth_header,
 is_processing_interrupted,
 sleep_with_interrupt,
+ to_aiohttp_url,
 )
 from .client import _diagnose_connectivity
 from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
@@ -94,7 +96,7 @@ async def download_url_to_bytesio(

 monitor_task = asyncio.create_task(_monitor())

- req_task = asyncio.create_task(session.get(url, headers=headers))
+ req_task = asyncio.create_task(session.get(to_aiohttp_url(url), headers=headers))
 done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)

 if monitor_task in done and req_task in pending:
@@ -165,27 +167,25 @@ async def download_url_to_bytesio(
 with contextlib.suppress(Exception):
 dest.seek(0)

- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=f"[streamed {written} bytes to dest]",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=f"[streamed {written} bytes to dest]",
+ )
 return
 except asyncio.CancelledError:
 raise ProcessingInterrupted("Task cancelled") from None
 except (ClientError, OSError) as e:
 if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
 await sleep_with_interrupt(delay, cls, None, None, None)
 delay *= retry_backoff
 continue
@@ -260,3 +260,38 @@ def _generate_operation_id(method: str, url: str, attempt: int) -> str:
 except Exception:
 slug = "download"
 return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
+
+
+async def download_url_to_file_3d(
+ url: str,
+ file_format: str,
+ *,
+ task_id: str | None = None,
+ timeout: float | None = None,
+ max_retries: int = 5,
+ cls: type[COMFY_IO.ComfyNode] | None = None,
+) -> Types.File3D:
+ """Downloads a 3D model file from a URL into memory as BytesIO.
+
+ If task_id is provided, also writes the file to disk in the output directory
+ for backward compatibility with the old save-to-disk behavior.
+    """
+    file_format = file_format.lstrip(".").lower()
+    data = BytesIO()
+    await download_url_to_bytesio(
+        url,
+        data,
+        timeout=timeout,
+        max_retries=max_retries,
+        cls=cls,
+    )
+
+    if task_id is not None:
+        # This is only for backward compatibility with the current behavior, where every 3D node is an output node.
+        # New API nodes should not use "task_id"; users should instead use the "SaveGLB" node to save results.
+        output_dir = Path(get_output_directory())
+        output_path = output_dir / f"{task_id}.{file_format}"
+        output_path.write_bytes(data.getvalue())
+        data.seek(0)
+
+    return Types.File3D(source=data, file_format=file_format)
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index e0cb4428d..fe0543d9b 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -8,7 +8,6 @@ from typing import Any
 
 import folder_paths
 
-# Get the logger instance
 logger = logging.getLogger(__name__)
 
 
@@ -91,38 +90,41 @@ def log_request_response(
     Filenames are sanitized and length-limited for cross-platform safety.
     If we still fail to write, we fall back to appending into api.log.
     """
-    log_dir = get_log_directory()
-    filepath = _build_log_filepath(log_dir, operation_id, request_url)
-
-    log_content: list[str] = []
-    log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
-    log_content.append(f"Operation ID: {operation_id}")
-    log_content.append("-" * 30 + " REQUEST " + "-" * 30)
-    log_content.append(f"Method: {request_method}")
-    log_content.append(f"URL: {request_url}")
-    if request_headers:
-        log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
-    if request_params:
-        log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
-    if request_data is not None:
-        log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
-
-    log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
-    if response_status_code is not None:
-        log_content.append(f"Status Code: {response_status_code}")
-    if response_headers:
-        log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
-    if response_content is not None:
-        log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
-    if error_message:
-        log_content.append(f"Error:\n{error_message}")
-
-    try:
-        with open(filepath, "w", encoding="utf-8") as f:
-            f.write("\n".join(log_content))
-        logger.debug("API log saved to: %s", filepath)
-    except Exception as e:
-        logger.error("Error writing API log to %s: %s", filepath, str(e))
+    try:
+        log_dir = get_log_directory()
+        filepath = _build_log_filepath(log_dir, operation_id, request_url)
+
+        log_content: list[str] = []
+        log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+        log_content.append(f"Operation ID: {operation_id}")
+        log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+        log_content.append(f"Method: {request_method}")
+        log_content.append(f"URL: {request_url}")
+        if request_headers:
+            log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+        if request_params:
+            log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+        if request_data is not None:
+            log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+        log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+        if response_status_code is not None:
+            log_content.append(f"Status Code: {response_status_code}")
+        if response_headers:
+            log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+        if response_content is not None:
+            log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+        if error_message:
+            log_content.append(f"Error:\n{error_message}")
+
+        # The nested try keeps the write-failure diagnostics, while the outer
+        # try guarantees that logging itself can never break an API request.
+        try:
+            with open(filepath, "w", encoding="utf-8") as f:
+                f.write("\n".join(log_content))
+            logger.debug("API log saved to: %s", filepath)
+        except Exception as e:
+            logger.error("Error writing API log to %s: %s", filepath, str(e))
+    except Exception as _log_e:
+        logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
 
 
 if __name__ == '__main__':
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index b8d33f4d1..6d1d107a1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -43,27 +43,41 @@ class UploadResponse(BaseModel):
 
 async def upload_images_to_comfyapi(
     cls: type[IO.ComfyNode],
-    image: torch.Tensor,
+    image: torch.Tensor | list[torch.Tensor],
     *,
     max_images: int = 8,
     mime_type: str | None = None,
     wait_label: str | None = "Uploading",
     show_batch_index: bool = True,
+    total_pixels: int | None = 2048 * 2048,
 ) -> list[str]:
     """
     Uploads images to ComfyUI API and returns download URLs.
-    To upload multiple images, stack them in the batch dimension first.
+    To upload multiple images, pass a batched tensor, a list of tensors, or a list of batches.
     """
+    tensors: list[torch.Tensor] = []
+    if isinstance(image, list):
+        for img in image:
+            is_batch = len(img.shape) > 3
+            if is_batch:
+                tensors.extend(img[i] for i in range(img.shape[0]))
+            else:
+                tensors.append(img)
+    else:
+        is_batch = len(image.shape) > 3
+        if is_batch:
+            tensors.extend(image[i] for i in range(image.shape[0]))
+        else:
+            tensors.append(image)
+
     # if batched, try to upload each file if max_images is greater than 0
     download_urls: list[str] = []
-    is_batch = len(image.shape) > 3
-    batch_len = image.shape[0] if is_batch else 1
-    num_to_upload = min(batch_len, max_images)
+    num_to_upload = min(len(tensors), max_images)
 
     batch_start_ts = time.monotonic()
     for idx in range(num_to_upload):
-        tensor = image[idx] if is_batch else image
-        img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
+        tensor = tensors[idx]
+        img_io = tensor_to_bytesio(tensor, total_pixels=total_pixels, mime_type=mime_type)
 
         effective_label = wait_label
         if wait_label and show_batch_index and num_to_upload > 1:
@@ -74,6 +88,28 @@ async def upload_images_to_comfyapi(
     return download_urls
 
 
+async def upload_image_to_comfyapi(
+    cls: type[IO.ComfyNode],
+    image: torch.Tensor,
+    *,
+    mime_type: str | None = None,
+    wait_label: str | None = "Uploading",
+    total_pixels: int | None = 2048 * 2048,
+) -> str:
+    """Uploads a single image to ComfyUI API and returns its download URL."""
+    return (
+        await upload_images_to_comfyapi(
+            cls,
+            image,
+            max_images=1,
+            mime_type=mime_type,
+            wait_label=wait_label,
+            show_batch_index=False,
+            total_pixels=total_pixels,
+        )
+    )[0]
+
+
 async def upload_audio_to_comfyapi(
     cls: type[IO.ComfyNode],
     audio: Input.Audio,
@@ -81,7 +117,6 @@ async def upload_audio_to_comfyapi(
     container_format: str = "mp4",
     codec_name: str = "aac",
     mime_type: str = "audio/mp4",
-    filename: str = "uploaded_audio.mp4",
 ) -> str:
     """
     Uploads a single audio input to ComfyUI API and returns its download URL.
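
(Editor's aside, not part of the patch: a minimal sketch of the input flattening
introduced above. `flatten_image_inputs`, `example_batch`, and `example_single`
are hypothetical names.)

    import torch

    def flatten_image_inputs(image) -> list[torch.Tensor]:
        # Accept a single HWC tensor, a batched BHWC tensor, or a list mixing
        # both, mirroring the branches in upload_images_to_comfyapi.
        tensors: list[torch.Tensor] = []
        for img in (image if isinstance(image, list) else [image]):
            if len(img.shape) > 3:  # batched BHWC
                tensors.extend(img[i] for i in range(img.shape[0]))
            else:  # single HWC image
                tensors.append(img)
        return tensors

    example_batch = torch.zeros(2, 64, 64, 3)  # a batch of two 64x64 RGB images
    example_single = torch.zeros(64, 64, 3)    # one 64x64 RGB image
    assert len(flatten_image_inputs([example_batch, example_single])) == 3
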
@@ -91,7 +126,7 @@ async def upload_audio_to_comfyapi( waveform: torch.Tensor = audio["waveform"] audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) - return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + return await upload_file_to_comfyapi(cls, audio_bytes_io, f"{uuid.uuid4()}.{container_format}", mime_type) async def upload_video_to_comfyapi( @@ -119,7 +154,7 @@ async def upload_video_to_comfyapi( raise ValueError(f"Could not verify video duration from source: {e}") from e upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" + filename = f"{uuid.uuid4()}.{container.value.lower()}" # Convert VideoInput to BytesIO using specified container/codec video_bytes_io = BytesIO() @@ -129,6 +164,27 @@ async def upload_video_to_comfyapi( return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) +_3D_MIME_TYPES = { + "glb": "model/gltf-binary", + "obj": "model/obj", + "fbx": "application/octet-stream", +} + + +async def upload_3d_model_to_comfyapi( + cls: type[IO.ComfyNode], + model_3d: Types.File3D, + file_format: str, +) -> str: + """Uploads a 3D model file to ComfyUI API and returns its download URL.""" + return await upload_file_to_comfyapi( + cls, + model_3d.get_data(), + f"{uuid.uuid4()}.{file_format}", + _3D_MIME_TYPES.get(file_format, "application/octet-stream"), + ) + + async def upload_file_to_comfyapi( cls: type[IO.ComfyNode], file_bytes_io: BytesIO, @@ -220,17 +276,14 @@ async def upload_file( monitor_task = asyncio.create_task(_monitor()) sess: aiohttp.ClientSession | None = None try: - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_params=None, - request_data=f"[File data {len(data)} bytes]", - ) - except Exception as e: - logging.debug("[DEBUG] upload request logging failed: %s", e) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) sess = aiohttp.ClientSession(timeout=timeout) req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) @@ -276,31 +329,27 @@ async def upload_file( delay *= retry_backoff continue raise Exception(f"Failed to upload (HTTP {resp.status}).") - try: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - except Exception as e: - logging.debug("[DEBUG] upload response logging failed: %s", e) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) return except asyncio.CancelledError: raise ProcessingInterrupted("Task cancelled") from None except (aiohttp.ClientError, OSError) as e: if attempt <= max_retries: - with contextlib.suppress(Exception): - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers or None, - request_data=f"[File data 
{len(data)} bytes]", - error_message=f"{type(e).__name__}: {str(e)} (will retry)", - ) + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) await sleep_with_interrupt( delay, cls, diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py new file mode 100644 index 000000000..d455d08e8 --- /dev/null +++ b/comfy_execution/cache_provider.py @@ -0,0 +1,138 @@ +from typing import Any, Optional, Tuple, List +import hashlib +import json +import logging +import threading + +# Public types — source of truth is comfy_api.latest._caching +from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported) + +_logger = logging.getLogger(__name__) + + +_providers: List[CacheProvider] = [] +_providers_lock = threading.Lock() +_providers_snapshot: Tuple[CacheProvider, ...] = () + + +def register_cache_provider(provider: CacheProvider) -> None: + """Register an external cache provider. Providers are called in registration order.""" + global _providers_snapshot + with _providers_lock: + if provider in _providers: + _logger.warning(f"Provider {provider.__class__.__name__} already registered") + return + _providers.append(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Registered cache provider: {provider.__class__.__name__}") + + +def unregister_cache_provider(provider: CacheProvider) -> None: + global _providers_snapshot + with _providers_lock: + try: + _providers.remove(provider) + _providers_snapshot = tuple(_providers) + _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}") + except ValueError: + _logger.warning(f"Provider {provider.__class__.__name__} was not registered") + + +def _get_cache_providers() -> Tuple[CacheProvider, ...]: + return _providers_snapshot + + +def _has_cache_providers() -> bool: + return bool(_providers_snapshot) + + +def _clear_cache_providers() -> None: + global _providers_snapshot + with _providers_lock: + _providers.clear() + _providers_snapshot = () + + +def _canonicalize(obj: Any) -> Any: + # Convert to canonical JSON-serializable form with deterministic ordering. + # Frozensets have non-deterministic iteration order between Python sessions. + # Raises ValueError for non-cacheable types (Unhashable, unknown) so that + # _serialize_cache_key returns None and external caching is skipped. 
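+    # For example, _canonicalize(frozenset({1, 2})) yields
+    # ("__frozenset__", [("int", 1), ("int", 2)]) regardless of the order the
+    # frozenset happens to iterate in, so the JSON string (and the SHA256
+    # digest derived from it) is stable across sessions.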
+ if isinstance(obj, frozenset): + return ("__frozenset__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, set): + return ("__set__", sorted( + [_canonicalize(item) for item in obj], + key=lambda x: json.dumps(x, sort_keys=True) + )) + elif isinstance(obj, tuple): + return ("__tuple__", [_canonicalize(item) for item in obj]) + elif isinstance(obj, list): + return [_canonicalize(item) for item in obj] + elif isinstance(obj, dict): + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} + elif isinstance(obj, (int, float, str, bool, type(None))): + return (type(obj).__name__, obj) + elif isinstance(obj, bytes): + return ("__bytes__", obj.hex()) + else: + raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}") + + +def _serialize_cache_key(cache_key: Any) -> Optional[str]: + # Returns deterministic SHA256 hex digest, or None on failure. + # Uses JSON (not pickle) because pickle is non-deterministic across sessions. + try: + canonical = _canonicalize(cache_key) + json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':')) + return hashlib.sha256(json_str.encode('utf-8')).hexdigest() + except Exception as e: + _logger.warning(f"Failed to serialize cache key: {e}") + return None + + +def _contains_self_unequal(obj: Any) -> bool: + # Local cache matches by ==. Values where not (x == x) (NaN, etc.) will + # never hit locally, but serialized form would match externally. Skip these. + try: + if not (obj == obj): + return True + except Exception: + return True + if isinstance(obj, (frozenset, tuple, list, set)): + return any(_contains_self_unequal(item) for item in obj) + if isinstance(obj, dict): + return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items()) + if hasattr(obj, 'value'): + return _contains_self_unequal(obj.value) + return False + + +def _estimate_value_size(value: CacheValue) -> int: + try: + import torch + except ImportError: + return 0 + + total = 0 + + def estimate(obj): + nonlocal total + if isinstance(obj, torch.Tensor): + total += obj.numel() * obj.element_size() + elif isinstance(obj, dict): + for v in obj.values(): + estimate(v) + elif isinstance(obj, (list, tuple)): + for item in obj: + estimate(item) + + for output in value.outputs: + estimate(output) + return total diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 326a279fc..78212bde3 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,3 +1,4 @@ +import asyncio import bisect import gc import itertools @@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet): self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) class BasicCache: - def __init__(self, key_class): + def __init__(self, key_class, enable_providers=False): self.key_class = key_class self.initialized = False + self.enable_providers = enable_providers self.dynprompt: DynamicPrompt self.cache_key_set: CacheKeySet self.cache = {} self.subcaches = {} + self._pending_store_tasks: set = set() async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt @@ -196,18 +199,138 @@ class BasicCache: def poll(self, **kwargs): pass - def _set_immediate(self, node_id, value): - assert self.initialized - cache_key = self.cache_key_set.get_data_key(node_id) - self.cache[cache_key] = value - - def _get_immediate(self, node_id): + def 
get_local(self, node_id): if not self.initialized: return None cache_key = self.cache_key_set.get_data_key(node_id) if cache_key in self.cache: return self.cache[cache_key] - else: + return None + + def set_local(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + async def _set_immediate(self, node_id, value): + assert self.initialized + cache_key = self.cache_key_set.get_data_key(node_id) + self.cache[cache_key] = value + + await self._notify_providers_store(node_id, cache_key, value) + + async def _get_immediate(self, node_id): + if not self.initialized: + return None + cache_key = self.cache_key_set.get_data_key(node_id) + + if cache_key in self.cache: + return self.cache[cache_key] + + external_result = await self._check_providers_lookup(node_id, cache_key) + if external_result is not None: + self.cache[cache_key] = external_result + return external_result + + return None + + async def _notify_providers_store(self, node_id, cache_key, value): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return + if not _has_cache_providers(): + return + if not self._is_external_cacheable_value(value): + return + if _contains_self_unequal(cache_key): + return + + context = self._build_context(node_id, cache_key) + if context is None: + return + cache_value = CacheValue(outputs=value.outputs, ui=value.ui) + + for provider in _get_cache_providers(): + try: + if provider.should_cache(context, cache_value): + task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value)) + self._pending_store_tasks.add(task) + task.add_done_callback(self._pending_store_tasks.discard) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}") + + @staticmethod + async def _safe_provider_store(provider, context, cache_value): + from comfy_execution.cache_provider import _logger + try: + await provider.on_store(context, cache_value) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}") + + async def _check_providers_lookup(self, node_id, cache_key): + from comfy_execution.cache_provider import ( + _has_cache_providers, _get_cache_providers, + CacheValue, _contains_self_unequal, _logger + ) + + if not self.enable_providers: + return None + if not _has_cache_providers(): + return None + if _contains_self_unequal(cache_key): + return None + + context = self._build_context(node_id, cache_key) + if context is None: + return None + + for provider in _get_cache_providers(): + try: + if not provider.should_cache(context): + continue + result = await provider.on_lookup(context) + if result is not None: + if not isinstance(result, CacheValue): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type") + continue + if not isinstance(result.outputs, (list, tuple)): + _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs") + continue + from execution import CacheEntry + return CacheEntry(ui=result.ui, outputs=list(result.outputs)) + except Exception as e: + _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}") + + return None + + def _is_external_cacheable_value(self, value): + return hasattr(value, 'outputs') and hasattr(value, 'ui') + + def _get_class_type(self, node_id): + if not self.initialized or 
not self.dynprompt: + return '' + try: + return self.dynprompt.get_node(node_id).get('class_type', '') + except Exception: + return '' + + def _build_context(self, node_id, cache_key): + from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger + try: + cache_key_hash = _serialize_cache_key(cache_key) + if cache_key_hash is None: + return None + return CacheContext( + node_id=node_id, + class_type=self._get_class_type(node_id), + cache_key_hash=cache_key_hash, + ) + except Exception as e: + _logger.warning(f"Failed to build cache context for node {node_id}: {e}") return None async def _ensure_subcache(self, node_id, children_ids): @@ -236,8 +359,8 @@ class BasicCache: return result class HierarchicalCache(BasicCache): - def __init__(self, key_class): - super().__init__(key_class) + def __init__(self, key_class, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) def _get_cache_for(self, node_id): assert self.dynprompt is not None @@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache): return None return cache - def get(self, node_id): + async def get(self, node_id): cache = self._get_cache_for(node_id) if cache is None: return None - return cache._get_immediate(node_id) + return await cache._get_immediate(node_id) - def set(self, node_id, value): + def get_local(self, node_id): + cache = self._get_cache_for(node_id) + if cache is None: + return None + return BasicCache.get_local(cache, node_id) + + async def set(self, node_id, value): cache = self._get_cache_for(node_id) assert cache is not None - cache._set_immediate(node_id, value) + await cache._set_immediate(node_id, value) + + def set_local(self, node_id, value): + cache = self._get_cache_for(node_id) + assert cache is not None + BasicCache.set_local(cache, node_id, value) async def ensure_subcache_for(self, node_id, children_ids): cache = self._get_cache_for(node_id) @@ -287,18 +421,24 @@ class NullCache: def poll(self, **kwargs): pass - def get(self, node_id): + async def get(self, node_id): return None - def set(self, node_id, value): + def get_local(self, node_id): + return None + + async def set(self, node_id, value): + pass + + def set_local(self, node_id, value): pass async def ensure_subcache_for(self, node_id, children_ids): return self class LRUCache(BasicCache): - def __init__(self, key_class, max_size=100): - super().__init__(key_class) + def __init__(self, key_class, max_size=100, enable_providers=False): + super().__init__(key_class, enable_providers=enable_providers) self.max_size = max_size self.min_generation = 0 self.generation = 0 @@ -322,18 +462,18 @@ class LRUCache(BasicCache): del self.children[key] self._clean_subcaches() - def get(self, node_id): + async def get(self, node_id): self._mark_used(node_id) - return self._get_immediate(node_id) + return await self._get_immediate(node_id) def _mark_used(self, node_id): cache_key = self.cache_key_set.get_data_key(node_id) if cache_key is not None: self.used_generation[cache_key] = self.generation - def set(self, node_id, value): + async def set(self, node_id, value): self._mark_used(node_id) - return self._set_immediate(node_id, value) + return await self._set_immediate(node_id, value) async def ensure_subcache_for(self, node_id, children_ids): # Just uses subcaches for tracking 'live' nodes @@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 class RAMPressureCache(LRUCache): - def __init__(self, key_class): - super().__init__(key_class, 0) + def __init__(self, key_class, 
enable_providers=False): + super().__init__(key_class, 0, enable_providers=enable_providers) self.timestamps = {} def clean_unused(self): self._clean_subcaches() - def set(self, node_id, value): + async def set(self, node_id, value): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - super().set(node_id, value) + await super().set(node_id, value) - def get(self, node_id): + async def get(self, node_id): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() - return super().get(node_id) + return await super().get(node_id) def poll(self, ram_headroom): def _ram_gb(): diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 0d811e354..c47f3c79b 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -97,6 +97,11 @@ def get_input_info( extra_info = input_info[1] else: extra_info = {} + # if input_type is a list, it is a Combo defined in outdated format; convert it. + # NOTE: uncomment this when we are confident old format going away won't cause too much trouble. + # if isinstance(input_type, list): + # extra_info["options"] = input_type + # input_type = IO.Combo.io_type return input_type, input_category, extra_info class TopologicalSort: @@ -199,24 +204,24 @@ class ExecutionList(TopologicalSort): self.execution_cache_listeners = {} def is_cached(self, node_id): - return self.output_cache.get(node_id) is not None + return self.output_cache.get_local(node_id) is not None def cache_link(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: self.execution_cache[to_node_id] = {} - self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) - if not from_node_id in self.execution_cache_listeners: + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id) + if from_node_id not in self.execution_cache_listeners: self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id].add(to_node_id) def get_cache(self, from_node_id, to_node_id): - if not to_node_id in self.execution_cache: + if to_node_id not in self.execution_cache: return None value = self.execution_cache[to_node_id].get(from_node_id) if value is None: return None #Write back to the main cache on touch. - self.output_cache.set(from_node_id, value) + self.output_cache.set_local(from_node_id, value) return value def cache_update(self, node_id, value): diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py index 59fb49357..fcd7ef735 100644 --- a/comfy_execution/jobs.py +++ b/comfy_execution/jobs.py @@ -14,15 +14,83 @@ class JobStatus: IN_PROGRESS = 'in_progress' COMPLETED = 'completed' FAILED = 'failed' + CANCELLED = 'cancelled' - ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED] + ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED] # Media types that can be previewed in the frontend -PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'}) +PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'}) # 3D file extensions for preview fallback (no dedicated media_type exists) -THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'}) +THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'}) + + +def has_3d_extension(filename: str) -> bool: + lower = filename.lower() + return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS) + + +def normalize_output_item(item): + """Normalize a single output list item for the jobs API. 
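+    For example, the bare string "mesh.glb" becomes
+    {"filename": "mesh.glb", "type": "output", "subfolder": "", "mediaType": "3d"}.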
+
+    Returns the normalized item, or None to exclude it.
+    String items with 3D extensions become {filename, type, subfolder, mediaType} dicts.
+    """
+    if item is None:
+        return None
+    if isinstance(item, str):
+        if has_3d_extension(item):
+            return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+        return None
+    if isinstance(item, dict):
+        return item
+    return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+    """Normalize raw node outputs for the jobs API.
+
+    Transforms string 3D filenames into file output dicts and removes
+    None items. All other items (non-3D strings, dicts, etc.) are
+    preserved as-is.
+    """
+    normalized = {}
+    for node_id, node_outputs in outputs.items():
+        if not isinstance(node_outputs, dict):
+            normalized[node_id] = node_outputs
+            continue
+        normalized_node = {}
+        for media_type, items in node_outputs.items():
+            if media_type == 'animated' or not isinstance(items, list):
+                normalized_node[media_type] = items
+                continue
+            normalized_items = []
+            for item in items:
+                if item is None:
+                    continue
+                norm = normalize_output_item(item)
+                normalized_items.append(norm if norm is not None else item)
+            normalized_node[media_type] = normalized_items
+        normalized[node_id] = normalized_node
+    return normalized
+
+
+# Text preview truncation limit (1024 characters) to prevent preview_output bloat
+TEXT_PREVIEW_MAX_LENGTH = 1024
+
+
+def _create_text_preview(value: str) -> dict:
+    """Create a text preview dict with optional truncation.
+
+    Returns:
+        dict with 'content' and optionally 'truncated' flag
+    """
+    if len(value) <= TEXT_PREVIEW_MAX_LENGTH:
+        return {'content': value}
+    return {
+        'content': value[:TEXT_PREVIEW_MAX_LENGTH],
+        'truncated': True
+    }
 
 
 def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -44,9 +112,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
 
     Maintains backwards compatibility with existing logic.
     Priority:
-    1. media_type is 'images', 'video', or 'audio'
+    1. media_type is 'images', 'video', 'audio', '3d', or 'text'
     2. format field starts with 'video/' or 'audio/'
-    3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+    3. 
filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz) """ if media_type in PREVIEWABLE_MEDIA_TYPES: return True @@ -94,12 +162,6 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: status_info = history_item.get('status', {}) status_str = status_info.get('status_str') if status_info else None - if status_str == 'success': - status = JobStatus.COMPLETED - elif status_str == 'error': - status = JobStatus.FAILED - else: - status = JobStatus.COMPLETED outputs = history_item.get('outputs', {}) outputs_count, preview_output = get_outputs_summary(outputs) @@ -107,6 +169,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_error = None execution_start_time = None execution_end_time = None + was_interrupted = False if status_info: messages = status_info.get('messages', []) for entry in messages: @@ -119,6 +182,15 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: execution_end_time = event_data.get('timestamp') if event_name == 'execution_error': execution_error = event_data + elif event_name == 'execution_interrupted': + was_interrupted = True + + if status_str == 'success': + status = JobStatus.COMPLETED + elif status_str == 'error': + status = JobStatus.CANCELLED if was_interrupted else JobStatus.FAILED + else: + status = JobStatus.COMPLETED job = prune_dict({ 'id': prompt_id, @@ -134,7 +206,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs: }) if include_outputs: - job['outputs'] = outputs + job['outputs'] = normalize_outputs(outputs) job['execution_status'] = status_info job['workflow'] = { 'prompt': prompt, @@ -167,15 +239,41 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]: for item in items: if not isinstance(item, dict): - continue + # Handle text outputs (non-dict items like strings or tuples) + normalized = normalize_output_item(item) + if normalized is None: + # Not a 3D file string — check for text preview + if media_type == 'text': + count += 1 + if preview_output is None: + if isinstance(item, tuple): + text_value = item[0] if item else '' + else: + text_value = str(item) + text_preview = _create_text_preview(text_value) + enriched = { + **text_preview, + 'nodeId': node_id, + 'mediaType': media_type + } + if fallback_preview is None: + fallback_preview = enriched + continue + # normalize_output_item returned a dict (e.g. 
3D file) + item = normalized + count += 1 - if preview_output is None and is_previewable(media_type, item): + if preview_output is not None: + continue + + if is_previewable(media_type, item): enriched = { **item, 'nodeId': node_id, - 'mediaType': media_type } + if 'mediaType' not in item: + enriched['mediaType'] = media_type if item.get('type') == 'output': preview_output = enriched elif fallback_preview is None: @@ -268,13 +366,13 @@ def get_all_jobs( for item in queued: jobs.append(normalize_queue_item(item, JobStatus.PENDING)) - include_completed = JobStatus.COMPLETED in status_filter - include_failed = JobStatus.FAILED in status_filter - if include_completed or include_failed: + history_statuses = {JobStatus.COMPLETED, JobStatus.FAILED, JobStatus.CANCELLED} + requested_history_statuses = history_statuses & set(status_filter) + if requested_history_statuses: for prompt_id, history_item in history.items(): - is_failed = history_item.get('status', {}).get('status_str') == 'error' - if (is_failed and include_failed) or (not is_failed and include_completed): - jobs.append(normalize_history_item(prompt_id, history_item)) + job = normalize_history_item(prompt_id, history_item) + if job.get('status') in requested_history_statuses: + jobs.append(job) if workflow_id: jobs = [j for j in jobs if j.get('workflow_id') == workflow_id] diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index 24c0b4ed7..e73624bd1 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -21,14 +21,24 @@ def validate_node_input( """ # If the types are exactly the same, we can return immediately # Use pre-union behaviour: inverse of `__ne__` + # NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class. if not received_type != input_type: return True + # If one of the types is '*', we can return True immediately; this is the 'Any' type. + if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type: + return True + # If the received type or input_type is a MatchType, we can return True immediately; # validation for this is handled by the frontend if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: return True + # This accounts for some custom nodes that output lists of options as the type; + # if we ever want to break them on purpose, this can be removed + if isinstance(received_type, list) and input_type == IO.Combo.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False @@ -37,6 +47,10 @@ def validate_node_input( received_types = set(t.strip() for t in received_type.split(",")) input_types = set(t.strip() for t in input_type.split(",")) + # If any of the types is '*', we can return True immediately; this is the 'Any' type. 
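+    # e.g. received_type "IMAGE,MASK" against input_type "*" splits into
+    # {"IMAGE", "MASK"} and {"*"}, and a bare "*" on either side matches anything.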
+    if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
+        return True
+
     if strict:
         # In strict mode, all received types must be in the input types
         return received_types.issubset(input_types)
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index 1409233c9..9cf84ab4d 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -28,12 +28,45 @@ class TextEncodeAceStepAudio(io.ComfyNode):
         conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
         return io.NodeOutput(conditioning)
 
+class TextEncodeAceStepAudio15(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="TextEncodeAceStepAudio1.5",
+            category="conditioning",
+            inputs=[
+                io.Clip.Input("clip"),
+                io.String.Input("tags", multiline=True, dynamic_prompts=True),
+                io.String.Input("lyrics", multiline=True, dynamic_prompts=True),
+                io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
+                io.Int.Input("bpm", default=120, min=10, max=300),
+                io.Float.Input("duration", default=120.0, min=0.0, max=2000.0, step=0.1),
+                io.Combo.Input("timesignature", options=['2', '3', '4', '6']),
+                io.Combo.Input("language", options=["en", "ja", "zh", "es", "de", "fr", "pt", "ru", "it", "nl", "pl", "tr", "vi", "cs", "fa", "id", "ko", "uk", "hu", "ar", "sv", "ro", "el"]),
+                io.Combo.Input("keyscale", options=[f"{root} {quality}" for quality in ["major", "minor"] for root in ["C", "C#", "Db", "D", "D#", "Eb", "E", "F", "F#", "Gb", "G", "G#", "Ab", "A", "A#", "Bb", "B"]]),
+                io.Boolean.Input("generate_audio_codes", default=True, tooltip="Enable the LLM that generates audio codes. This can be slow but will increase the quality of the generated audio. Turn this off if you are giving the model an audio reference.", advanced=True),
+                io.Float.Input("cfg_scale", default=2.0, min=0.0, max=100.0, step=0.1, advanced=True),
+                io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
+                io.Float.Input("top_p", default=0.9, min=0.0, max=1.0, step=0.01, advanced=True),
+                io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
+                io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
+            ],
+            outputs=[io.Conditioning.Output()],
+        )
+
+    @classmethod
+    def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
+        tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        return io.NodeOutput(conditioning)
+
 class EmptyAceStepLatentAudio(io.ComfyNode):
     @classmethod
     def define_schema(cls):
         return io.Schema(
             node_id="EmptyAceStepLatentAudio",
+            display_name="Empty Ace Step 1.0 Latent Audio",
             category="latent/audio",
             inputs=[
                 io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
@@ -51,12 +84,61 @@ class EmptyAceStepLatentAudio(io.ComfyNode):
         return io.NodeOutput({"samples": latent, "type": "audio"})
 
 
+class EmptyAceStep15LatentAudio(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="EmptyAceStep1.5LatentAudio",
+            display_name="Empty Ace Step 1.5 Latent Audio",
+            category="latent/audio",
inputs=[ + io.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), + io.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + ), + ], + outputs=[io.Latent.Output()], + ) + + @classmethod + def execute(cls, seconds, batch_size) -> io.NodeOutput: + length = round((seconds * 48000 / 1920)) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent, "type": "audio"}) + +class ReferenceAudio(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReferenceTimbreAudio", + display_name="Reference Audio", + category="advanced/conditioning/audio", + is_experimental=True, + description="This node sets the reference audio for ace step 1.5", + inputs=[ + io.Conditioning.Input("conditioning"), + io.Latent.Input("latent", optional=True), + ], + outputs=[ + io.Conditioning.Output(), + ] + ) + + @classmethod + def execute(cls, conditioning, latent=None) -> io.NodeOutput: + if latent is not None: + conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_audio_timbre_latents": [latent["samples"]]}, append=True) + return io.NodeOutput(conditioning) + class AceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeAceStepAudio, EmptyAceStepLatentAudio, + TextEncodeAceStepAudio15, + EmptyAceStep15LatentAudio, + ReferenceAudio, ] async def comfy_entrypoint() -> AceExtension: diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 5532ffe6a..7f716cd76 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -47,8 +47,8 @@ class SamplerLCMUpscale(io.ComfyNode): node_id="SamplerLCMUpscale", category="sampling/custom_sampling/samplers", inputs=[ - io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01), - io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1), + io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True), + io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True), io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), ], outputs=[io.Sampler.Output()], @@ -94,7 +94,7 @@ class SamplerEulerCFGpp(io.ComfyNode): display_name="SamplerEulerCFG++", category="_for_testing", # "sampling/custom_sampling/samplers" inputs=[ - io.Combo.Input("version", options=["regular", "alternative"]), + io.Combo.Input("version", options=["regular", "alternative"], advanced=True), ], outputs=[io.Sampler.Output()], is_experimental=True, diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index edd5dadd4..4fc511d2c 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -28,6 +28,7 @@ class AlignYourStepsScheduler(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="AlignYourStepsScheduler", + search_aliases=["AYS scheduler"], category="sampling/custom_sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index f27ae7da8..fd561d360 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -26,6 +26,7 @@ class APG(io.ComfyNode): max=10.0, step=0.01, tooltip="Controls the scale of the parallel guidance vector. 
Default CFG behavior at a setting of 1.", + advanced=True, ), io.Float.Input( "norm_threshold", @@ -34,6 +35,7 @@ class APG(io.ComfyNode): max=50.0, step=0.1, tooltip="Normalize guidance vector to this value, normalization disable at a setting of 0.", + advanced=True, ), io.Float.Input( "momentum", @@ -42,6 +44,7 @@ class APG(io.ComfyNode): max=1.0, step=0.01, tooltip="Controls a running average of guidance during diffusion, disabled at a setting of 0.", + advanced=True, ), ], outputs=[io.Model.Output()], @@ -55,7 +58,8 @@ class APG(io.ComfyNode): def pre_cfg_function(args): nonlocal running_avg, prev_sigma - if len(args["conds_out"]) == 1: return args["conds_out"] + if len(args["conds_out"]) == 1: + return args["conds_out"] cond = args["conds_out"][0] uncond = args["conds_out"][1] diff --git a/comfy_extras/nodes_attention_multiply.py b/comfy_extras/nodes_attention_multiply.py index c0e494c2a..060a5c9be 100644 --- a/comfy_extras/nodes_attention_multiply.py +++ b/comfy_extras/nodes_attention_multiply.py @@ -28,10 +28,10 @@ class UNetSelfAttentionMultiply(io.ComfyNode): category="_for_testing/attention_experiments", inputs=[ io.Model.Input("model"), - io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[io.Model.Output()], is_experimental=True, @@ -51,10 +51,10 @@ class UNetCrossAttentionMultiply(io.ComfyNode): category="_for_testing/attention_experiments", inputs=[ io.Model.Input("model"), - io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[io.Model.Output()], is_experimental=True, @@ -71,13 +71,14 @@ class CLIPAttentionMultiply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CLIPAttentionMultiply", + search_aliases=["clip attention scale", "text encoder attention"], category="_for_testing/attention_experiments", inputs=[ io.Clip.Input("clip"), - io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("k", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("v", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("out", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[io.Clip.Output()], is_experimental=True, @@ 
-108,10 +109,10 @@ class UNetTemporalAttentionMultiply(io.ComfyNode): category="_for_testing/attention_experiments", inputs=[ io.Model.Input("model"), - io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01), + io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("self_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("cross_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), + io.Float.Input("cross_temporal", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[io.Model.Output()], is_experimental=True, diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443c..a395392d8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -19,10 +19,11 @@ class EmptyLatentAudio(IO.ComfyNode): node_id="EmptyLatentAudio", display_name="Empty Latent Audio", category="latent/audio", + essentials_category="Audio", inputs=[ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), IO.Int.Input( - "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch." + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.", ), ], outputs=[IO.Latent.Output()], @@ -69,6 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEEncodeAudio", + search_aliases=["audio to latent"], display_name="VAE Encode Audio", category="latent/audio", inputs=[ @@ -81,22 +83,37 @@ class VAEEncodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] - if 44100 != sample_rate: - waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, vae_sample_rate) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return IO.NodeOutput({"samples":t}) + return IO.NodeOutput({"samples": t}) encode = execute # TODO: remove +def vae_decode_audio(vae, samples, tile=None, overlap=None): + if tile is not None: + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) + else: + audio = vae.decode(samples["samples"]).movedim(-1, 1) + + std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 + std[std < 1.0] = 1.0 + audio /= std + vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} + + class VAEDecodeAudio(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="VAEDecodeAudio", + search_aliases=["latent to audio"], display_name="VAE Decode Audio", category="latent/audio", inputs=[ @@ -108,22 +125,42 @@ class VAEDecodeAudio(IO.ComfyNode): @classmethod def execute(cls, vae, samples) -> IO.NodeOutput: - audio = vae.decode(samples["samples"]).movedim(-1, 1) - std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 - std[std < 1.0] = 1.0 - audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return 
IO.NodeOutput(vae_decode_audio(vae, samples)) decode = execute # TODO: remove +class VAEDecodeAudioTiled(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudioTiled", + search_aliases=["latent to audio"], + display_name="VAE Decode Audio (Tiled)", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("tile_size", default=512, min=32, max=8192, step=8), + IO.Int.Input("overlap", default=64, min=0, max=1024, step=8), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, vae, samples, tile_size, overlap) -> IO.NodeOutput: + return IO.NodeOutput(vae_decode_audio(vae, samples, tile_size, overlap)) + + class SaveAudio(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="SaveAudio", + search_aliases=["export flac"], display_name="Save Audio (FLAC)", category="audio", + essentials_category="Audio", inputs=[ IO.Audio.Input("audio"), IO.String.Input("filename_prefix", default="audio/ComfyUI"), @@ -146,8 +183,10 @@ class SaveAudioMP3(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioMP3", + search_aliases=["export mp3"], display_name="Save Audio (MP3)", category="audio", + essentials_category="Audio", inputs=[ IO.Audio.Input("audio"), IO.String.Input("filename_prefix", default="audio/ComfyUI"), @@ -173,6 +212,7 @@ class SaveAudioOpus(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveAudioOpus", + search_aliases=["export opus"], display_name="Save Audio (Opus)", category="audio", inputs=[ @@ -200,6 +240,7 @@ class PreviewAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="PreviewAudio", + search_aliases=["play audio"], display_name="Preview Audio", category="audio", inputs=[ @@ -259,8 +300,10 @@ class LoadAudio(IO.ComfyNode): files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) return IO.Schema( node_id="LoadAudio", + search_aliases=["import audio", "open audio", "audio file"], display_name="Load Audio", category="audio", + essentials_category="Audio", inputs=[ IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), ], @@ -296,6 +339,7 @@ class RecordAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="RecordAudio", + search_aliases=["microphone input", "audio capture", "voice input"], display_name="Record Audio", category="audio", inputs=[ @@ -320,6 +364,7 @@ class TrimAudioDuration(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TrimAudioDuration", + search_aliases=["cut audio", "audio clip", "shorten audio"], display_name="Trim Audio Duration", description="Trim audio tensor into chosen time range.", category="audio", @@ -372,6 +417,7 @@ class SplitAudioChannels(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SplitAudioChannels", + search_aliases=["stereo to mono"], display_name="Split Audio Channels", description="Separates the audio into left and right channels.", category="audio", @@ -399,6 +445,58 @@ class SplitAudioChannels(IO.ComfyNode): separate = execute # TODO: remove +class JoinAudioChannels(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="JoinAudioChannels", + display_name="Join Audio Channels", + description="Joins left and right mono audio channels into a stereo audio.", + category="audio", + inputs=[ + IO.Audio.Input("audio_left"), + IO.Audio.Input("audio_right"), + ], + outputs=[ + IO.Audio.Output(display_name="audio"), + ], + 
)
+
+    @classmethod
+    def execute(cls, audio_left, audio_right) -> IO.NodeOutput:
+        waveform_left = audio_left["waveform"]
+        sample_rate_left = audio_left["sample_rate"]
+        waveform_right = audio_right["waveform"]
+        sample_rate_right = audio_right["sample_rate"]
+
+        if waveform_left.shape[1] != 1 or waveform_right.shape[1] != 1:
+            raise ValueError("JoinAudioChannels: Both input audios must be mono.")
+
+        # Handle different sample rates by resampling to the higher rate
+        waveform_left, waveform_right, output_sample_rate = match_audio_sample_rates(
+            waveform_left, sample_rate_left, waveform_right, sample_rate_right
+        )
+
+        # Handle different lengths by trimming to the shorter length
+        length_left = waveform_left.shape[-1]
+        length_right = waveform_right.shape[-1]
+
+        if length_left != length_right:
+            min_length = min(length_left, length_right)
+            if length_left > min_length:
+                logging.info(f"JoinAudioChannels: Trimming left channel from {length_left} to {min_length} samples.")
+                waveform_left = waveform_left[..., :min_length]
+            if length_right > min_length:
+                logging.info(f"JoinAudioChannels: Trimming right channel from {length_right} to {min_length} samples.")
+                waveform_right = waveform_right[..., :min_length]
+
+        # Join the channels into stereo
+        left_channel = waveform_left[..., 0:1, :]
+        right_channel = waveform_right[..., 0:1, :]
+        stereo_waveform = torch.cat([left_channel, right_channel], dim=1)
+
+        return IO.NodeOutput({"waveform": stereo_waveform, "sample_rate": output_sample_rate})
+
 
 def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
     if sample_rate_1 != sample_rate_2:
@@ -420,6 +518,7 @@ class AudioConcat(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="AudioConcat",
+            search_aliases=["join audio", "combine audio", "append audio"],
             display_name="Audio Concat",
             description="Concatenates the audio1 to audio2 in the specified direction.",
             category="audio",
@@ -467,6 +566,7 @@ class AudioMerge(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="AudioMerge",
+            search_aliases=["mix audio", "overlay audio", "layer audio"],
             display_name="Audio Merge",
             description="Combine two audio tracks by overlaying their waveforms.",
             category="audio",
@@ -527,6 +627,7 @@ class AudioAdjustVolume(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="AudioAdjustVolume",
+            search_aliases=["audio gain", "loudness", "audio level"],
             display_name="Audio Adjust Volume",
             category="audio",
             inputs=[
@@ -562,6 +663,7 @@ class EmptyAudio(IO.ComfyNode):
     def define_schema(cls):
         return IO.Schema(
             node_id="EmptyAudio",
+            search_aliases=["blank audio"],
             display_name="Empty Audio",
             category="audio",
             inputs=[
@@ -579,6 +681,7 @@ class EmptyAudio(IO.ComfyNode):
                     tooltip="Sample rate of the empty audio clip.",
                     min=1,
                     max=192000,
+                    advanced=True,
                 ),
                 IO.Int.Input(
                     "channels",
@@ -586,6 +689,7 @@ class EmptyAudio(IO.ComfyNode):
                     default=2,
                     min=1,
                     max=2,
                    tooltip="Number of audio channels (1 for mono, 2 for stereo).",
+                    advanced=True,
                ),
            ],
            outputs=[IO.Audio.Output()],
@@ -600,6 +704,67 @@ class EmptyAudio(IO.ComfyNode):
 
     create_empty_audio = execute  # TODO: remove
 
 
+class AudioEqualizer3Band(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="AudioEqualizer3Band",
+            search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
+            display_name="Audio Equalizer (3-Band)",
+            category="audio",
+            is_experimental=True,
+            inputs=[
+                IO.Audio.Input("audio"),
+                IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain 
for Low frequencies (Bass)"), + IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"), + IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"), + IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"), + IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"), + IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"), + IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"), + ], + outputs=[IO.Audio.Output()], + ) + + @classmethod + def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput: + waveform = audio["waveform"] + sample_rate = audio["sample_rate"] + eq_waveform = waveform.clone() + + # 1. Apply Low Shelf (Bass) + if low_gain_dB != 0: + eq_waveform = torchaudio.functional.bass_biquad( + eq_waveform, + sample_rate, + gain=low_gain_dB, + central_freq=float(low_freq), + Q=0.707 + ) + + # 2. Apply Peaking EQ (Mids) + if mid_gain_dB != 0: + eq_waveform = torchaudio.functional.equalizer_biquad( + eq_waveform, + sample_rate, + center_freq=float(mid_freq), + gain=mid_gain_dB, + Q=mid_q + ) + + # 3. Apply High Shelf (Treble) + if high_gain_dB != 0: + eq_waveform = torchaudio.functional.treble_biquad( + eq_waveform, + sample_rate, + gain=high_gain_dB, + central_freq=float(high_freq), + Q=0.707 + ) + + return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate}) + + class AudioExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -607,6 +772,7 @@ class AudioExtension(ComfyExtension): EmptyLatentAudio, VAEEncodeAudio, VAEDecodeAudio, + VAEDecodeAudioTiled, SaveAudio, SaveAudioMP3, SaveAudioOpus, @@ -616,10 +782,12 @@ class AudioExtension(ComfyExtension): RecordAudio, TrimAudioDuration, SplitAudioChannels, + JoinAudioChannels, AudioConcat, AudioMerge, AudioAdjustVolume, EmptyAudio, + AudioEqualizer3Band, ] async def comfy_entrypoint() -> AudioExtension: diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index eb7ef363c..e7efa29ba 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -174,10 +174,10 @@ class WanCameraEmbedding(io.ComfyNode): io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4), io.Float.Input("speed", default=1.0, min=0, max=10.0, step=0.1, optional=True), - io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True), - io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True), - io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True), - io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True), + io.Float.Input("fx", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True), + io.Float.Input("fy", default=0.5, min=0, max=1, step=0.000000001, optional=True, advanced=True), + io.Float.Input("cx", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True), + io.Float.Input("cy", default=0.5, min=0, max=1, step=0.01, optional=True, advanced=True), ], outputs=[ io.WanCameraEmbedding.Output(display_name="camera_embedding"), diff --git a/comfy_extras/nodes_canny.py 
b/comfy_extras/nodes_canny.py index 576f3640a..648b4279d 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,6 +3,7 @@ from typing_extensions import override import comfy.model_management from comfy_api.latest import ComfyExtension, io +import torch class Canny(io.ComfyNode): @@ -10,7 +11,10 @@ class Canny(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Canny", + display_name="Canny", + search_aliases=["edge detection", "outline", "contour detection", "line art"], category="image/preprocessors", + essentials_category="Image Tools", inputs=[ io.Image.Input("image"), io.Float.Input("low_threshold", default=0.4, min=0.01, max=0.99, step=0.01), @@ -26,8 +30,8 @@ class Canny(io.ComfyNode): @classmethod def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: - output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1) return io.NodeOutput(img_out) diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index 381989818..509436062 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -48,6 +48,7 @@ class ChromaRadianceOptions(io.ComfyNode): min=0.0, max=1.0, tooltip="First sigma that these options will be in effect.", + advanced=True, ), io.Float.Input( id="end_sigma", @@ -55,12 +56,14 @@ class ChromaRadianceOptions(io.ComfyNode): min=0.0, max=1.0, tooltip="Last sigma that these options will be in effect.", + advanced=True, ), io.Int.Input( id="nerf_tile_size", default=-1, min=-1, tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 
0 means use non-tiling mode (may require a lot of VRAM).",
+                advanced=True,
             ),
         ],
         outputs=[io.Model.Output()],
diff --git a/comfy_extras/nodes_clip_sdxl.py b/comfy_extras/nodes_clip_sdxl.py
index 520ff0e3c..7a001af6f 100644
--- a/comfy_extras/nodes_clip_sdxl.py
+++ b/comfy_extras/nodes_clip_sdxl.py
@@ -35,8 +35,8 @@ class CLIPTextEncodeSDXL(io.ComfyNode):
             io.Clip.Input("clip"),
             io.Int.Input("width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
             io.Int.Input("height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
-            io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION),
-            io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION),
+            io.Int.Input("crop_w", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
+            io.Int.Input("crop_h", default=0, min=0, max=nodes.MAX_RESOLUTION, advanced=True),
             io.Int.Input("target_width", default=1024, min=0, max=nodes.MAX_RESOLUTION),
             io.Int.Input("target_height", default=1024, min=0, max=nodes.MAX_RESOLUTION),
             io.String.Input("text_g", multiline=True, dynamic_prompts=True),
diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py
new file mode 100644
index 000000000..80ba121cd
--- /dev/null
+++ b/comfy_extras/nodes_color.py
@@ -0,0 +1,42 @@
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+
+
+class ColorToRGBInt(io.ComfyNode):
+    @classmethod
+    def define_schema(cls) -> io.Schema:
+        return io.Schema(
+            node_id="ColorToRGBInt",
+            display_name="Color to RGB Int",
+            category="utils",
+            description="Convert a color to an RGB integer value.",
+            inputs=[
+                io.Color.Input("color"),
+            ],
+            outputs=[
+                io.Int.Output(display_name="rgb_int"),
+            ],
+        )
+
+    @classmethod
+    def execute(
+        cls,
+        color: str,
+    ) -> io.NodeOutput:
+        # expect format #RRGGBB
+        if len(color) != 7 or color[0] != "#":
+            raise ValueError("Color must be in format #RRGGBB")
+        r = int(color[1:3], 16)
+        g = int(color[3:5], 16)
+        b = int(color[5:7], 16)
+        return io.NodeOutput(r * 256 * 256 + g * 256 + b)
+
+
+class ColorExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [ColorToRGBInt]
+
+
+async def comfy_entrypoint() -> ColorExtension:
+    return ColorExtension()
diff --git a/comfy_extras/nodes_compositing.py b/comfy_extras/nodes_compositing.py
index e4e4e1cbc..3bc9fccb3 100644
--- a/comfy_extras/nodes_compositing.py
+++ b/comfy_extras/nodes_compositing.py
@@ -109,6 +109,7 @@ class PorterDuffImageComposite(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="PorterDuffImageComposite",
+            search_aliases=["alpha composite", "blend modes", "layer blend", "transparency blend"],
             display_name="Porter-Duff Image Composite",
             category="mask/compositing",
             inputs=[
@@ -165,6 +166,7 @@ class SplitImageWithAlpha(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="SplitImageWithAlpha",
+            search_aliases=["extract alpha", "separate transparency", "remove alpha"],
             display_name="Split Image with Alpha",
             category="mask/compositing",
             inputs=[
@@ -188,6 +190,7 @@ class JoinImageWithAlpha(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="JoinImageWithAlpha",
+            search_aliases=["add transparency", "apply alpha", "composite alpha", "RGBA"],
             display_name="Join Image with Alpha",
             category="mask/compositing",
             inputs=[
diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py
index 8b06e3de9..86426a780 100644
--- a/comfy_extras/nodes_cond.py
+++ b/comfy_extras/nodes_cond.py
@@ -38,8 +38,8 @@ class T5TokenizerOptions(io.ComfyNode):
category="_for_testing/conditioning", inputs=[ io.Clip.Input("clip"), - io.Int.Input("min_padding", default=0, min=0, max=10000, step=1), - io.Int.Input("min_length", default=0, min=0, max=10000, step=1), + io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True), + io.Int.Input("min_length", default=0, min=0, max=10000, step=1, advanced=True), ], outputs=[io.Clip.Output()], is_experimental=True, diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 3799a9004..0e43f2e44 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -14,21 +14,21 @@ class ContextWindowsManualNode(io.ComfyNode): description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window."), - io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window."), + io.Int.Input("context_length", min=1, default=16, tooltip="The length of the context window.", advanced=True), + io.Int.Input("context_overlap", min=0, default=4, tooltip="The overlap of the context window.", advanced=True), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -67,15 +67,15 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): schema.description = "Manually set context windows for WAN-like models (dim=2)." 
schema.inputs = [ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), - io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window."), - io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window."), + io.Int.Input("context_length", min=1, max=nodes.MAX_RESOLUTION, step=4, default=81, tooltip="The length of the context window.", advanced=True), + io.Int.Input("context_overlap", min=0, default=30, tooltip="The overlap of the context window.", advanced=True), io.Combo.Input("context_schedule", options=[ comfy.context_windows.ContextSchedules.STATIC_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_STANDARD, comfy.context_windows.ContextSchedules.UNIFORM_LOOPED, comfy.context_windows.ContextSchedules.BATCHED, ], tooltip="The stride of the context window."), - io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), + io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules.", advanced=True), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index e835feed7..847cb0bdf 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -38,6 +38,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ControlNetInpaintingAliMamaApply", + search_aliases=["masked controlnet"], category="conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), @@ -47,8 +48,8 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): io.Image.Input("image"), io.Mask.Input("mask"), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), - io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py new file mode 100644 index 000000000..9016a84f9 --- /dev/null +++ b/comfy_extras/nodes_curve.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from comfy_api.latest import ComfyExtension, io +from comfy_api.input import CurveInput +from typing_extensions import override + + +class CurveEditor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CurveEditor", + display_name="Curve Editor", + category="utils", + inputs=[ + io.Curve.Input("curve"), + io.Histogram.Input("histogram", optional=True), + ], + outputs=[ + io.Curve.Output("curve"), + ], + ) + + @classmethod + def execute(cls, curve, histogram=None) -> io.NodeOutput: + result = CurveInput.from_raw(curve) + + ui = {} + if histogram is not 
None: + ui["histogram"] = histogram if isinstance(histogram, list) else list(histogram) + + return io.NodeOutput(result, ui=ui) if ui else io.NodeOutput(result) + + +class CurveExtension(ComfyExtension): + @override + async def get_node_list(self): + return [CurveEditor] + + +async def comfy_entrypoint(): + return CurveExtension() diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 7ee4caac1..1e957c09b 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -9,6 +9,7 @@ import comfy.utils import node_helpers from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import re class BasicScheduler(io.ComfyNode): @@ -49,9 +50,9 @@ class KarrasScheduler(io.ComfyNode): category="sampling/custom_sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -71,8 +72,8 @@ class ExponentialScheduler(io.ComfyNode): category="sampling/custom_sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -92,9 +93,9 @@ class PolyexponentialScheduler(io.ComfyNode): category="sampling/custom_sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -114,10 +115,10 @@ class LaplaceScheduler(io.ComfyNode): category="sampling/custom_sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), - io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, 
advanced=True), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False, advanced=True), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -163,8 +164,8 @@ class BetaSamplingScheduler(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), - io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -184,9 +185,9 @@ class VPScheduler(io.ComfyNode): category="sampling/custom_sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), - io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values - io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), - io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False, advanced=True), ], outputs=[io.Sigmas.Output()] ) @@ -296,6 +297,7 @@ class ExtendIntermediateSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ExtendIntermediateSigmas", + search_aliases=["interpolate sigmas"], category="sampling/custom_sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), @@ -396,9 +398,9 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode): node_id="SamplerDPMPP_3M_SDE", category="sampling/custom_sampling/samplers", inputs=[ - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Combo.Input("noise_device", options=['gpu', 'cpu']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -422,9 +424,9 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode): category="sampling/custom_sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=['midpoint', 'heun']), - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Combo.Input("noise_device", options=['gpu', 'cpu']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -448,10 +450,10 @@ class SamplerDPMPP_SDE(io.ComfyNode): node_id="SamplerDPMPP_SDE", category="sampling/custom_sampling/samplers", inputs=[ - io.Float.Input("eta", default=1.0, min=0.0, 
max=100.0, step=0.01, round=False), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), - io.Combo.Input("noise_device", options=['gpu', 'cpu']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Combo.Input("noise_device", options=['gpu', 'cpu'], advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -494,8 +496,8 @@ class SamplerEulerAncestral(io.ComfyNode): node_id="SamplerEulerAncestral", category="sampling/custom_sampling/samplers", inputs=[ - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -536,7 +538,7 @@ class SamplerLMS(io.ComfyNode): return io.Schema( node_id="SamplerLMS", category="sampling/custom_sampling/samplers", - inputs=[io.Int.Input("order", default=4, min=1, max=100)], + inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)], outputs=[io.Sampler.Output()] ) @@ -554,16 +556,16 @@ class SamplerDPMAdaptative(io.ComfyNode): node_id="SamplerDPMAdaptative", category="sampling/custom_sampling/samplers", inputs=[ - io.Int.Input("order", default=3, min=2, max=3), - io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("order", default=3, min=2, max=3, advanced=True), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -586,9 +588,9 @@ class SamplerER_SDE(io.ComfyNode): 
category="sampling/custom_sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), - io.Int.Input("max_stage", default=3, min=1, max=3), - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type.", advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -620,17 +622,18 @@ class SamplerSASolver(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerSASolver", + search_aliases=["sde"], category="sampling/custom_sampling/samplers", inputs=[ io.Model.Input("model"), - io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), - io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), - io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), - io.Int.Input("predictor_order", default=3, min=1, max=6), - io.Int.Input("corrector_order", default=4, min=0, max=6), - io.Boolean.Input("use_pece"), - io.Boolean.Input("simple_order_2"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), + io.Int.Input("predictor_order", default=3, min=1, max=6, advanced=True), + io.Int.Input("corrector_order", default=4, min=0, max=6, advanced=True), + io.Boolean.Input("use_pece", advanced=True), + io.Boolean.Input("simple_order_2", advanced=True), ], outputs=[io.Sampler.Output()] ) @@ -664,12 +667,13 @@ class SamplerSEEDS2(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerSEEDS2", + search_aliases=["sde", "exp heun"], category="sampling/custom_sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), - io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength"), - io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier"), - io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)"), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="SDE noise multiplier", advanced=True), + io.Float.Input("r", default=0.5, min=0.01, max=1.0, step=0.01, round=False, tooltip="Relative step size for the intermediate stage (c2 node)", advanced=True), ], outputs=[io.Sampler.Output()], description=( @@ -699,7 +703,14 @@ class Noise_EmptyNoise: def 
generate_noise(self, input_latent): latent_image = input_latent["samples"] - return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") + if latent_image.is_nested: + tensors = latent_image.unbind() + zeros = [] + for t in tensors: + zeros.append(torch.zeros(t.shape, dtype=t.dtype, layout=t.layout, device="cpu")) + return comfy.nested_tensor.NestedTensor(zeros) + else: + return torch.zeros(latent_image.shape, dtype=latent_image.dtype, layout=latent_image.layout, device="cpu") class Noise_RandomNoise: @@ -719,7 +730,7 @@ class SamplerCustom(io.ComfyNode): category="sampling/custom_sampling", inputs=[ io.Model.Input("model"), - io.Boolean.Input("add_noise", default=True), + io.Boolean.Input("add_noise", default=True, advanced=True), io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), io.Conditioning.Input("positive"), @@ -739,7 +750,7 @@ class SamplerCustom(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image if not add_noise: @@ -758,10 +769,15 @@ class SamplerCustom(io.ComfyNode): samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: + x0_out = model.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes)) out_denoised = latent.copy() - out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) + out_denoised["samples"] = x0_out else: out_denoised = out return io.NodeOutput(out, out_denoised) @@ -851,6 +867,7 @@ class DualCFGGuider(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DualCFGGuider", + search_aliases=["dual prompt guidance"], category="sampling/custom_sampling/guiders", inputs=[ io.Model.Input("model"), @@ -878,6 +895,7 @@ class DisableNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DisableNoise", + search_aliases=["zero noise"], category="sampling/custom_sampling/noise", inputs=[], outputs=[io.Noise.Output()] @@ -931,7 +949,7 @@ class SamplerCustomAdvanced(io.ComfyNode): latent = latent_image latent_image = latent["samples"] latent = latent.copy() - latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image) + latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None)) latent["samples"] = latent_image noise_mask = None @@ -946,10 +964,15 @@ class SamplerCustomAdvanced(io.ComfyNode): samples = samples.to(comfy.model_management.intermediate_device()) out = latent.copy() + out.pop("downscale_ratio_spacial", None) out["samples"] = samples if "x0" in x0_output: + x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) + if samples.is_nested: + latent_shapes = [x.shape for x in samples.unbind()] + x0_out = 
comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(x0_out, latent_shapes))
             out_denoised = latent.copy()
-            out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())
+            out_denoised["samples"] = x0_out
         else:
             out_denoised = out
         return io.NodeOutput(out, out_denoised)
@@ -1005,6 +1028,26 @@ class AddNoise(io.ComfyNode):

     add_noise = execute

+class ManualSigmas(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ManualSigmas",
+            search_aliases=["custom noise schedule", "define sigmas"],
+            category="_for_testing/custom_sampling",
+            is_experimental=True,
+            inputs=[
+                io.String.Input("sigmas", default="1, 0.5", multiline=False)
+            ],
+            outputs=[io.Sigmas.Output()]
+        )
+
+    @classmethod
+    def execute(cls, sigmas) -> io.NodeOutput:
+        sigmas = re.findall(r"[-+]?(?:\d*\.?\d+)", sigmas)
+        sigmas = [float(i) for i in sigmas]
+        sigmas = torch.FloatTensor(sigmas)
+        return io.NodeOutput(sigmas)

 class CustomSamplersExtension(ComfyExtension):
     @override
@@ -1044,6 +1087,7 @@ class CustomSamplersExtension(ComfyExtension):
             DisableNoise,
             AddNoise,
             SamplerCustomAdvanced,
+            ManualSigmas,
         ]

diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 513aecf3a..98ed25d7e 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -222,6 +222,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
                 "filename_prefix",
                 default="image",
                 tooltip="Prefix for saved image filenames.",
+                advanced=True,
             ),
         ],
         outputs=[],
@@ -262,6 +263,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
                 "filename_prefix",
                 default="image",
                 tooltip="Prefix for saved image filenames.",
+                advanced=True,
             ),
         ],
         outputs=[],
@@ -667,16 +669,19 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode):

     @classmethod
     def _process(cls, image, longer_edge):
-        img = tensor_to_pil(image)
-        w, h = img.size
-        if w > h:
-            new_w = longer_edge
-            new_h = int(h * (longer_edge / w))
-        else:
-            new_h = longer_edge
-            new_w = int(w * (longer_edge / h))
-        img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
-        return pil_to_tensor(img)
+        resized_images = []
+        for image_i in image:
+            img = tensor_to_pil(image_i)
+            w, h = img.size
+            if w > h:
+                new_w = longer_edge
+                new_h = int(h * (longer_edge / w))
+            else:
+                new_h = longer_edge
+                new_w = int(w * (longer_edge / h))
+            img = img.resize((new_w, new_h), Image.Resampling.LANCZOS)
+            resized_images.append(pil_to_tensor(img))
+        return torch.cat(resized_images, dim=0)

 class CenterCropImagesNode(ImageProcessingNode):
@@ -738,6 +743,7 @@ class NormalizeImagesNode(ImageProcessingNode):
                 min=0.0,
                 max=1.0,
                 tooltip="Mean value for normalization.",
+                advanced=True,
             ),
             io.Float.Input(
                 "std",
@@ -745,6 +751,7 @@ class NormalizeImagesNode(ImageProcessingNode):
                 min=0.001,
                 max=1.0,
                 tooltip="Standard deviation for normalization.",
+                advanced=True,
             ),
         ]

@@ -958,6 +965,7 @@ class ImageDeduplicationNode(ImageProcessingNode):
                 min=0.0,
                 max=1.0,
                 tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.",
+                advanced=True,
             ),
         ]

@@ -1036,6 +1044,7 @@ class ImageGridNode(ImageProcessingNode):
                 min=32,
                 max=2048,
                 tooltip="Width of each cell in the grid.",
+                advanced=True,
             ),
             io.Int.Input(
                 "cell_height",
@@ -1043,9 +1052,10 @@ class ImageGridNode(ImageProcessingNode):
                 min=32,
                 max=2048,
                 tooltip="Height of each cell in the grid.",
+                advanced=True,
             ),
             io.Int.Input(
-                "padding", default=4, min=0, max=50, tooltip="Padding between images."
+ "padding", default=4, min=0, max=50, tooltip="Padding between images.", advanced=True ), ] @@ -1220,11 +1230,11 @@ class ResolutionBucket(io.ComfyNode): class MakeTrainingDataset(io.ComfyNode): """Encode images with VAE and texts with CLIP to create a training dataset.""" - @classmethod def define_schema(cls): return io.Schema( node_id="MakeTrainingDataset", + search_aliases=["encode dataset"], display_name="Make Training Dataset", category="dataset", is_experimental=True, @@ -1306,11 +1316,11 @@ class MakeTrainingDataset(io.ComfyNode): class SaveTrainingDataset(io.ComfyNode): """Save encoded training dataset (latents + conditioning) to disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="SaveTrainingDataset", + search_aliases=["export training data"], display_name="Save Training Dataset", category="dataset", is_experimental=True, @@ -1336,6 +1346,7 @@ class SaveTrainingDataset(io.ComfyNode): min=1, max=100000, tooltip="Number of samples per shard file.", + advanced=True, ), ], outputs=[], @@ -1407,11 +1418,11 @@ class SaveTrainingDataset(io.ComfyNode): class LoadTrainingDataset(io.ComfyNode): """Load encoded training dataset from disk.""" - @classmethod def define_schema(cls): return io.Schema( node_id="LoadTrainingDataset", + search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", category="dataset", is_experimental=True, diff --git a/comfy_extras/nodes_differential_diffusion.py b/comfy_extras/nodes_differential_diffusion.py index 6dfdf466c..34ffb9a89 100644 --- a/comfy_extras/nodes_differential_diffusion.py +++ b/comfy_extras/nodes_differential_diffusion.py @@ -11,6 +11,7 @@ class DifferentialDiffusion(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="DifferentialDiffusion", + search_aliases=["inpaint gradient", "variable denoise strength"], display_name="Differential Diffusion", category="_for_testing", inputs=[ diff --git a/comfy_extras/nodes_easycache.py b/comfy_extras/nodes_easycache.py index 11b23ffdb..923c2bb05 100644 --- a/comfy_extras/nodes_easycache.py +++ b/comfy_extras/nodes_easycache.py @@ -9,6 +9,14 @@ if TYPE_CHECKING: from uuid import UUID +def _extract_tensor(data, output_channels): + """Extract tensor from data, handling both single tensors and lists.""" + if isinstance(data, list): + # LTX2 AV tensors: [video, audio] + return data[0][:, :output_channels], data[1][:, :output_channels] + return data[:, :output_channels], None + + def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args transformer_options: dict[str] = args[-1] @@ -17,7 +25,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] - x: torch.Tensor = args[0][:, :easycache.output_channels] + x, ax = _extract_tensor(args[0], easycache.output_channels) sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and easycache.is_past_end_timestep(sigmas): @@ -29,11 +37,17 @@ def easycache_forward_wrapper(executor, *args, **kwargs): do_easycache = easycache.should_do_easycache(sigmas) if do_easycache: easycache.check_metadata(x) + # if there isn't a cache diff for current conds, we cannot skip this step + can_apply_cache_diff = easycache.can_apply_cache_diff(uuids) # if first cond marked this step for skipping, skip it and use appropriate cached values - if easycache.skip_current_step: + if easycache.skip_current_step and can_apply_cache_diff: if 
easycache.verbose: logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}") - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result if easycache.initial_step: easycache.first_cond_uuid = uuids[0] has_first_cond_uuid = easycache.has_first_cond_uuid(uuids) @@ -44,18 +58,23 @@ def easycache_forward_wrapper(executor, *args, **kwargs): if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate(): approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm easycache.cumulative_change_rate += approx_output_change_rate - if easycache.cumulative_change_rate < easycache.reuse_threshold: + if easycache.cumulative_change_rate < easycache.reuse_threshold and can_apply_cache_diff: if easycache.verbose: logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") # other conds should also skip this step, and instead use their cached values easycache.skip_current_step = True - return easycache.apply_cache_diff(x, uuids) + result = easycache.apply_cache_diff(x, uuids) + if ax is not None: + result_audio = easycache.apply_cache_diff(ax, uuids, is_audio=True) + return [result, result_audio] + return result else: if easycache.verbose: logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}") easycache.cumulative_change_rate = 0.0 - output: torch.Tensor = executor(*args, **kwargs) + full_output: torch.Tensor = executor(*args, **kwargs) + output, audio_output = _extract_tensor(full_output, easycache.output_channels) if has_first_cond_uuid and easycache.has_output_prev_norm(): output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean() if easycache.verbose: @@ -72,13 +91,15 @@ def easycache_forward_wrapper(executor, *args, **kwargs): logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}") # TODO: allow cache_diff to be offloaded easycache.update_cache_diff(output, next_x_prev, uuids) + if audio_output is not None: + easycache.update_cache_diff(audio_output, ax, uuids, is_audio=True) if has_first_cond_uuid: easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids) easycache.output_prev_subsampled = easycache.subsample(output, uuids) easycache.output_prev_norm = output.flatten().abs().mean() if easycache.verbose: logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}") - return output + return full_output def lazycache_predict_noise_wrapper(executor, *args, **kwargs): # get values from args @@ -87,8 +108,8 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs): easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) - # prepare next x_prev x: torch.Tensor = args[0][:, :easycache.output_channels] + # prepare next x_prev next_x_prev = x input_change = None do_easycache = easycache.should_do_easycache(timestep) @@ -195,6 +216,7 @@ class EasyCacheHolder: self.output_prev_subsampled: torch.Tensor = None self.output_prev_norm: torch.Tensor = None 
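+        # Per-cond residual caches: update_cache_diff() stores (output - x) for
+        # each cond UUID, and apply_cache_diff() replays that diff on skipped
+        # steps. Audio diffs are kept in a separate dict so AV models (e.g.
+        # LTX2) can reuse video and audio latents independently.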
self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
+        self.uuid_cache_diffs_audio: dict[UUID, torch.Tensor] = {}
         self.output_change_rates = []
         self.approx_output_change_rates = []
         self.total_steps_skipped = 0
@@ -240,20 +262,24 @@
             return to_return.clone()
         return to_return

-    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
-        if self.first_cond_uuid in uuids:
+    def can_apply_cache_diff(self, uuids: list[UUID]) -> bool:
+        return all(uuid in self.uuid_cache_diffs for uuid in uuids)
+
+    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        if self.first_cond_uuid in uuids and not is_audio:
             self.total_steps_skipped += 1
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         batch_offset = x.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
             # slice out only what is relevant to this cond
             batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
             # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
-            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
+            if x.shape[1:] != cache_diffs[uuid].shape[1:]:
                 if not self.allow_mismatch:
-                    raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
+                    raise ValueError(f"Cached dims {cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                 slicing = []
                 skip_this_dim = True
-                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
+                for dim_u, dim_x in zip(cache_diffs[uuid].shape, x.shape):
                     if skip_this_dim:
                         skip_this_dim = False
                         continue
@@ -265,10 +291,11 @@
                 else:
                     slicing.append(slice(None))
                 batch_slice = batch_slice + slicing
-            x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
+            x[tuple(batch_slice)] += cache_diffs[uuid].to(x.device)
         return x

-    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
+    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID], is_audio: bool = False):
+        cache_diffs = self.uuid_cache_diffs_audio if is_audio else self.uuid_cache_diffs
         # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
         if output.shape[1:] != x.shape[1:]:
             if not self.allow_mismatch:
@@ -288,7 +315,7 @@
         diff = output - x
         batch_offset = diff.shape[0] // len(uuids)
         for i, uuid in enumerate(uuids):
-            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
+            cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]
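+    # A rough sketch of the reuse math implemented above (values illustrative,
+    # not taken from a real model):
+    #   full step:    out_N = model(x_N);  diff = out_N - x_N    # cached per uuid
+    #   skipped step: out_{N+1} ~= x_{N+1} + diff                # apply_cache_diff()
+    # The forward wrapper only takes the skipped path while the accumulated
+    # approximate change rate stays below reuse_threshold and a cached diff
+    # exists for every cond uuid in the batch.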
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool: return self.first_cond_uuid in uuids @@ -319,6 +346,8 @@ class EasyCacheHolder: self.output_prev_norm = None del self.uuid_cache_diffs self.uuid_cache_diffs = {} + del self.uuid_cache_diffs_audio + self.uuid_cache_diffs_audio = {} self.total_steps_skipped = 0 self.state_metadata = None return self @@ -338,10 +367,10 @@ class EasyCacheNode(io.ComfyNode): is_experimental=True, inputs=[ io.Model.Input("model", tooltip="The model to add EasyCache to."), - io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), - io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."), - io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."), - io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache.", advanced=True), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache.", advanced=True), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True), ], outputs=[ io.Model.Output(tooltip="The model with EasyCache."), @@ -471,10 +500,10 @@ class LazyCacheNode(io.ComfyNode): is_experimental=True, inputs=[ io.Model.Input("model", tooltip="The model to add LazyCache to."), - io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."), - io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."), - io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."), - io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."), + io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps.", advanced=True), + io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache.", advanced=True), + io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache.", advanced=True), + io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information.", advanced=True), ], outputs=[ io.Model.Output(tooltip="The model with LazyCache."), diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py index 4d8061741..0fb3871c8 100644 --- a/comfy_extras/nodes_eps.py +++ b/comfy_extras/nodes_eps.py @@ -28,6 +28,7 @@ class EpsilonScaling(io.ComfyNode): max=1.5, step=0.001, display_mode=io.NumberDisplay.number, + advanced=True, ), ], outputs=[ @@ -97,6 +98,7 @@ class TemporalScoreRescaling(io.ComfyNode): max=100.0, step=0.001, display_mode=io.NumberDisplay.number, + advanced=True, ), io.Float.Input( "tsr_sigma", @@ -109,6 +111,7 @@ class TemporalScoreRescaling(io.ComfyNode): max=100.0, step=0.001, display_mode=io.NumberDisplay.number, + 
advanced=True, ), ], outputs=[ diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 12c8ed3e6..3a23c7d04 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -6,6 +6,7 @@ import comfy.model_management import torch import math import nodes +import comfy.ldm.flux.math class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -161,6 +162,7 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): io.Combo.Input( "reference_latents_method", options=["offset", "index", "uxo/uno", "index_timestep_zero"], + advanced=True, ), ], outputs=[ @@ -230,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode): sigmas = get_schedule(steps, round(seq_len)) return io.NodeOutput(sigmas) +class KV_Attn_Input: + def __init__(self): + self.cache = {} + + def __call__(self, q, k, v, extra_options, **kwargs): + reference_image_num_tokens = extra_options.get("reference_image_num_tokens", []) + if len(reference_image_num_tokens) == 0: + return {} + + ref_toks = sum(reference_image_num_tokens) + cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"]) + if cache_key in self.cache: + kk, vv = self.cache[cache_key] + self.set_cache = False + return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)} + + self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone()) + self.set_cache = True + return {"q": q, "k": k, "v": v} + + def cleanup(self): + self.cache = {} + + +class FluxKVCache(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="FluxKVCache", + display_name="Flux KV Cache", + description="Enables KV Cache optimization for reference images on Flux family models.", + category="", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to use KV Cache on."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with KV Cache enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type) -> io.NodeOutput: + m = model.clone() + input_patch_obj = KV_Attn_Input() + + def model_input_patch(inputs): + if len(input_patch_obj.cache) > 0: + ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", [])) + if ref_image_tokens > 0: + img = inputs["img"] + inputs["img"] = img[:, :-ref_image_tokens] + return inputs + + m.set_model_attn1_patch(input_patch_obj) + m.set_model_post_input_patch(model_input_patch) + if hasattr(model.model.diffusion_model, "params"): + m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero") + else: + m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero") + + return io.NodeOutput(m) class FluxExtension(ComfyExtension): @override @@ -242,6 +306,7 @@ class FluxExtension(ComfyExtension): FluxKontextMultiReferenceLatentMethod, EmptyFlux2LatentImage, Flux2Scheduler, + FluxKVCache, ] diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 3429b731e..248efdef3 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -32,10 +32,10 @@ class FreeU(IO.ComfyNode): category="model_patches/unet", inputs=[ IO.Model.Input("model"), - IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), - IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), - IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), - IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True), + 
IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[ IO.Model.Output(), @@ -79,10 +79,10 @@ class FreeU_V2(IO.ComfyNode): category="model_patches/unet", inputs=[ IO.Model.Input("model"), - IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), - IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), - IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), - IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01, advanced=True), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01, advanced=True), ], outputs=[ IO.Model.Output(), diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py index f308eb0c1..eab4f303f 100644 --- a/comfy_extras/nodes_fresca.py +++ b/comfy_extras/nodes_fresca.py @@ -58,17 +58,18 @@ class FreSca(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FreSca", + search_aliases=["frequency guidance"], display_name="FreSca", category="_for_testing", description="Applies frequency-dependent scaling to the guidance", inputs=[ io.Model.Input("model"), io.Float.Input("scale_low", default=1.0, min=0, max=10, step=0.01, - tooltip="Scaling factor for low-frequency components"), + tooltip="Scaling factor for low-frequency components", advanced=True), io.Float.Input("scale_high", default=1.25, min=0, max=10, step=0.01, - tooltip="Scaling factor for high-frequency components"), + tooltip="Scaling factor for high-frequency components", advanced=True), io.Int.Input("freq_cutoff", default=20, min=1, max=10000, step=1, - tooltip="Number of frequency indices around center to consider as low-frequency"), + tooltip="Number of frequency indices around center to consider as low-frequency", advanced=True), ], outputs=[ io.Model.Output(), diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 25367560a..d48483862 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -342,7 +342,7 @@ class GITSScheduler(io.ComfyNode): node_id="GITSScheduler", category="sampling/custom_sampling/schedulers", inputs=[ - io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05), + io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True), io.Int.Input("steps", default=10, min=2, max=1000), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), ], diff --git a/comfy_extras/nodes_glsl.py b/comfy_extras/nodes_glsl.py new file mode 100644 index 000000000..2a59a9285 --- /dev/null +++ b/comfy_extras/nodes_glsl.py @@ -0,0 +1,896 @@ +import os +import sys +import re +import logging +import ctypes.util +import importlib.util +from typing import TypedDict + +import numpy as np +import torch + +import nodes +from comfy_api.latest import ComfyExtension, io, ui +from typing_extensions import override +from utils.install_util import get_missing_requirements_message + +logger = logging.getLogger(__name__) + + +def _check_opengl_availability(): + """Early check for OpenGL availability. 
Raises RuntimeError if unlikely to work.""" + logger.debug("_check_opengl_availability: starting") + missing = [] + + # Check Python packages (using find_spec to avoid importing) + logger.debug("_check_opengl_availability: checking for glfw package") + if importlib.util.find_spec("glfw") is None: + missing.append("glfw") + + logger.debug("_check_opengl_availability: checking for OpenGL package") + if importlib.util.find_spec("OpenGL") is None: + missing.append("PyOpenGL") + + if missing: + raise RuntimeError( + f"OpenGL dependencies not available.\n{get_missing_requirements_message()}\n" + ) + + # On Linux without display, check if headless backends are available + logger.debug(f"_check_opengl_availability: platform={sys.platform}") + if sys.platform.startswith("linux"): + has_display = os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY") + logger.debug(f"_check_opengl_availability: has_display={bool(has_display)}") + if not has_display: + # Check for EGL or OSMesa libraries + logger.debug("_check_opengl_availability: checking for EGL library") + has_egl = ctypes.util.find_library("EGL") + logger.debug("_check_opengl_availability: checking for OSMesa library") + has_osmesa = ctypes.util.find_library("OSMesa") + + # Error disabled for CI as it fails this check + # if not has_egl and not has_osmesa: + # raise RuntimeError( + # "GLSL Shader node: No display and no headless backend (EGL/OSMesa) found.\n" + # "See error below for installation instructions." + # ) + logger.debug(f"Headless mode: EGL={'yes' if has_egl else 'no'}, OSMesa={'yes' if has_osmesa else 'no'}") + + logger.debug("_check_opengl_availability: completed") + + +# Run early check at import time +logger.debug("nodes_glsl: running _check_opengl_availability at import time") +_check_opengl_availability() + +# OpenGL modules - initialized lazily when context is created +gl = None +glfw = None +EGL = None + + +def _import_opengl(): + """Import OpenGL module. Called after context is created.""" + global gl + if gl is None: + logger.debug("_import_opengl: importing OpenGL.GL") + import OpenGL.GL as _gl + gl = _gl + logger.debug("_import_opengl: import completed") + return gl + + +class SizeModeInput(TypedDict): + size_mode: str + width: int + height: int + + +MAX_IMAGES = 5 # u_image0-4 +MAX_UNIFORMS = 5 # u_float0-4, u_int0-4 +MAX_OUTPUTS = 4 # fragColor0-3 (MRT) + +# Vertex shader using gl_VertexID trick - no VBO needed. 
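+# (One oversized triangle is used instead of a two-triangle quad so the
+# rasterizer never shades along a diagonal seam; clipping trims the overhang.)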
+# Draws a single triangle that covers the entire screen: +# +# (-1,3) +# /| +# / | <- visible area is the unit square from (-1,-1) to (1,1) +# / | parts outside get clipped away +# (-1,-1)---(3,-1) +# +# v_texCoord is computed from clip space: * 0.5 + 0.5 maps (-1,1) -> (0,1) +VERTEX_SHADER = """#version 330 core +out vec2 v_texCoord; +void main() { + vec2 verts[3] = vec2[](vec2(-1, -1), vec2(3, -1), vec2(-1, 3)); + v_texCoord = verts[gl_VertexID] * 0.5 + 0.5; + gl_Position = vec4(verts[gl_VertexID], 0, 1); +} +""" + +DEFAULT_FRAGMENT_SHADER = """#version 300 es +precision highp float; + +uniform sampler2D u_image0; +uniform vec2 u_resolution; + +in vec2 v_texCoord; +layout(location = 0) out vec4 fragColor0; + +void main() { + fragColor0 = texture(u_image0, v_texCoord); +} +""" + + +def _convert_es_to_desktop(source: str) -> str: + """Convert GLSL ES (WebGL) shader source to desktop GLSL 330 core.""" + # Remove any existing #version directive + source = re.sub(r"#version\s+\d+(\s+es)?\s*\n?", "", source, flags=re.IGNORECASE) + # Remove precision qualifiers (not needed in desktop GLSL) + source = re.sub(r"precision\s+(lowp|mediump|highp)\s+\w+\s*;\s*\n?", "", source) + # Prepend desktop GLSL version + return "#version 330 core\n" + source + + +def _detect_output_count(source: str) -> int: + """Detect how many fragColor outputs are used in the shader. + + Returns the count of outputs needed (1 to MAX_OUTPUTS). + """ + matches = re.findall(r"fragColor(\d+)", source) + if not matches: + return 1 # Default to 1 output if none found + max_index = max(int(m) for m in matches) + return min(max_index + 1, MAX_OUTPUTS) + + +def _detect_pass_count(source: str) -> int: + """Detect multi-pass rendering from #pragma passes N directive. + + Returns the number of passes (1 if not specified). + """ + match = re.search(r'#pragma\s+passes\s+(\d+)', source) + if match: + return max(1, int(match.group(1))) + return 1 + + +def _init_glfw(): + """Initialize GLFW. Returns (window, glfw_module). Raises RuntimeError on failure.""" + logger.debug("_init_glfw: starting") + # On macOS, glfw.init() must be called from main thread or it hangs forever + if sys.platform == "darwin": + logger.debug("_init_glfw: skipping on macOS") + raise RuntimeError("GLFW backend not supported on macOS") + + logger.debug("_init_glfw: importing glfw module") + import glfw as _glfw + + logger.debug("_init_glfw: calling glfw.init()") + if not _glfw.init(): + raise RuntimeError("glfw.init() failed") + + try: + logger.debug("_init_glfw: setting window hints") + _glfw.window_hint(_glfw.VISIBLE, _glfw.FALSE) + _glfw.window_hint(_glfw.CONTEXT_VERSION_MAJOR, 3) + _glfw.window_hint(_glfw.CONTEXT_VERSION_MINOR, 3) + _glfw.window_hint(_glfw.OPENGL_PROFILE, _glfw.OPENGL_CORE_PROFILE) + + logger.debug("_init_glfw: calling create_window()") + window = _glfw.create_window(64, 64, "ComfyUI GLSL", None, None) + if not window: + raise RuntimeError("glfw.create_window() failed") + + logger.debug("_init_glfw: calling make_context_current()") + _glfw.make_context_current(window) + logger.debug("_init_glfw: completed successfully") + return window, _glfw + except Exception: + logger.debug("_init_glfw: failed, terminating glfw") + _glfw.terminate() + raise + + +def _init_egl(): + """Initialize EGL for headless rendering. Returns (display, context, surface, EGL_module). 
Raises RuntimeError on failure.""" + logger.debug("_init_egl: starting") + from OpenGL import EGL as _EGL + from OpenGL.EGL import ( + eglGetDisplay, eglInitialize, eglChooseConfig, eglCreateContext, + eglMakeCurrent, eglCreatePbufferSurface, eglBindAPI, + eglTerminate, eglDestroyContext, eglDestroySurface, + EGL_DEFAULT_DISPLAY, EGL_NO_CONTEXT, EGL_NONE, + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_RED_SIZE, EGL_GREEN_SIZE, EGL_BLUE_SIZE, EGL_ALPHA_SIZE, EGL_DEPTH_SIZE, + EGL_WIDTH, EGL_HEIGHT, EGL_OPENGL_API, + ) + logger.debug("_init_egl: imports completed") + + display = None + context = None + surface = None + + try: + logger.debug("_init_egl: calling eglGetDisplay()") + display = eglGetDisplay(EGL_DEFAULT_DISPLAY) + if display == _EGL.EGL_NO_DISPLAY: + raise RuntimeError("eglGetDisplay() failed") + + logger.debug("_init_egl: calling eglInitialize()") + major, minor = _EGL.EGLint(), _EGL.EGLint() + if not eglInitialize(display, major, minor): + display = None # Not initialized, don't terminate + raise RuntimeError("eglInitialize() failed") + logger.debug(f"_init_egl: EGL version {major.value}.{minor.value}") + + config_attribs = [ + EGL_SURFACE_TYPE, EGL_PBUFFER_BIT, + EGL_RENDERABLE_TYPE, EGL_OPENGL_BIT, + EGL_RED_SIZE, 8, EGL_GREEN_SIZE, 8, EGL_BLUE_SIZE, 8, EGL_ALPHA_SIZE, 8, + EGL_DEPTH_SIZE, 0, EGL_NONE + ] + configs = (_EGL.EGLConfig * 1)() + num_configs = _EGL.EGLint() + if not eglChooseConfig(display, config_attribs, configs, 1, num_configs) or num_configs.value == 0: + raise RuntimeError("eglChooseConfig() failed") + config = configs[0] + logger.debug(f"_init_egl: config chosen, num_configs={num_configs.value}") + + if not eglBindAPI(EGL_OPENGL_API): + raise RuntimeError("eglBindAPI() failed") + + logger.debug("_init_egl: calling eglCreateContext()") + context_attribs = [ + _EGL.EGL_CONTEXT_MAJOR_VERSION, 3, + _EGL.EGL_CONTEXT_MINOR_VERSION, 3, + _EGL.EGL_CONTEXT_OPENGL_PROFILE_MASK, _EGL.EGL_CONTEXT_OPENGL_CORE_PROFILE_BIT, + EGL_NONE + ] + context = eglCreateContext(display, config, EGL_NO_CONTEXT, context_attribs) + if context == EGL_NO_CONTEXT: + raise RuntimeError("eglCreateContext() failed") + + logger.debug("_init_egl: calling eglCreatePbufferSurface()") + pbuffer_attribs = [EGL_WIDTH, 64, EGL_HEIGHT, 64, EGL_NONE] + surface = eglCreatePbufferSurface(display, config, pbuffer_attribs) + if surface == _EGL.EGL_NO_SURFACE: + raise RuntimeError("eglCreatePbufferSurface() failed") + + logger.debug("_init_egl: calling eglMakeCurrent()") + if not eglMakeCurrent(display, surface, surface, context): + raise RuntimeError("eglMakeCurrent() failed") + + logger.debug("_init_egl: completed successfully") + return display, context, surface, _EGL + + except Exception: + logger.debug("_init_egl: failed, cleaning up") + # Clean up any resources on failure + if surface is not None: + eglDestroySurface(display, surface) + if context is not None: + eglDestroyContext(display, context) + if display is not None: + eglTerminate(display) + raise + + +def _init_osmesa(): + """Initialize OSMesa for software rendering. Returns (context, buffer). 
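+    OSMesa rasterizes entirely in software on the CPU, so it is the slowest option
+    and is only tried after the GLFW and EGL backends fail.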
Raises RuntimeError on failure.""" + import ctypes + + logger.debug("_init_osmesa: starting") + os.environ["PYOPENGL_PLATFORM"] = "osmesa" + + logger.debug("_init_osmesa: importing OpenGL.osmesa") + from OpenGL import GL as _gl + from OpenGL.osmesa import ( + OSMesaCreateContextExt, OSMesaMakeCurrent, OSMesaDestroyContext, + OSMESA_RGBA, + ) + logger.debug("_init_osmesa: imports completed") + + ctx = OSMesaCreateContextExt(OSMESA_RGBA, 24, 0, 0, None) + if not ctx: + raise RuntimeError("OSMesaCreateContextExt() failed") + + width, height = 64, 64 + buffer = (ctypes.c_ubyte * (width * height * 4))() + + logger.debug("_init_osmesa: calling OSMesaMakeCurrent()") + if not OSMesaMakeCurrent(ctx, buffer, _gl.GL_UNSIGNED_BYTE, width, height): + OSMesaDestroyContext(ctx) + raise RuntimeError("OSMesaMakeCurrent() failed") + + logger.debug("_init_osmesa: completed successfully") + return ctx, buffer + + +class GLContext: + """Manages OpenGL context and resources for shader execution. + + Tries backends in order: GLFW (desktop) → EGL (headless GPU) → OSMesa (software). + """ + + _instance = None + _initialized = False + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + if GLContext._initialized: + logger.debug("GLContext.__init__: already initialized, skipping") + return + + logger.debug("GLContext.__init__: starting initialization") + + global glfw, EGL + + import time + start = time.perf_counter() + + self._backend = None + self._window = None + self._egl_display = None + self._egl_context = None + self._egl_surface = None + self._osmesa_ctx = None + self._osmesa_buffer = None + self._vao = None + + # Try backends in order: GLFW → EGL → OSMesa + errors = [] + + logger.debug("GLContext.__init__: trying GLFW backend") + try: + self._window, glfw = _init_glfw() + self._backend = "glfw" + logger.debug("GLContext.__init__: GLFW backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: GLFW backend failed: {e}") + errors.append(("GLFW", e)) + + if self._backend is None: + logger.debug("GLContext.__init__: trying EGL backend") + try: + self._egl_display, self._egl_context, self._egl_surface, EGL = _init_egl() + self._backend = "egl" + logger.debug("GLContext.__init__: EGL backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: EGL backend failed: {e}") + errors.append(("EGL", e)) + + if self._backend is None: + logger.debug("GLContext.__init__: trying OSMesa backend") + try: + self._osmesa_ctx, self._osmesa_buffer = _init_osmesa() + self._backend = "osmesa" + logger.debug("GLContext.__init__: OSMesa backend succeeded") + except Exception as e: + logger.debug(f"GLContext.__init__: OSMesa backend failed: {e}") + errors.append(("OSMesa", e)) + + if self._backend is None: + if sys.platform == "win32": + platform_help = ( + "Windows: Ensure GPU drivers are installed and display is available.\n" + " CPU-only/headless mode is not supported on Windows." 
+                )
+            elif sys.platform == "darwin":
+                platform_help = (
+                    "macOS: GLFW is not supported.\n"
+                    "  Install OSMesa via Homebrew: brew install mesa\n"
+                    "  Then: pip install PyOpenGL PyOpenGL-accelerate"
+                )
+            else:
+                platform_help = (
+                    "Linux: Install one of these backends:\n"
+                    "  Desktop: sudo apt install libgl1-mesa-glx libglfw3\n"
+                    "  Headless with GPU: sudo apt install libegl1-mesa libgl1-mesa-dri\n"
+                    "  Headless (CPU): sudo apt install libosmesa6"
+                )
+
+            error_details = "\n".join(f"  {name}: {err}" for name, err in errors)
+            raise RuntimeError(
+                f"Failed to create OpenGL context.\n\n"
+                f"Backend errors:\n{error_details}\n\n"
+                f"{platform_help}"
+            )
+
+        # Now import OpenGL.GL (after context is current)
+        logger.debug("GLContext.__init__: importing OpenGL.GL")
+        _import_opengl()
+
+        # Create VAO (required for core profile, but OSMesa may use compat profile)
+        logger.debug("GLContext.__init__: creating VAO")
+        vao = None  # pre-declare so the except block below cannot hit a NameError
+        try:
+            vao = gl.glGenVertexArrays(1)
+            gl.glBindVertexArray(vao)
+            self._vao = vao  # Only store after successful bind
+            logger.debug("GLContext.__init__: VAO created successfully")
+        except Exception as e:
+            logger.debug(f"GLContext.__init__: VAO creation failed (may be expected for OSMesa): {e}")
+            # OSMesa with older Mesa may not support VAOs
+            # Clean up if we created but couldn't bind
+            if vao is not None:
+                try:
+                    gl.glDeleteVertexArrays(1, [vao])
+                except Exception:
+                    pass
+
+        elapsed = (time.perf_counter() - start) * 1000
+
+        # Log device info
+        renderer = gl.glGetString(gl.GL_RENDERER)
+        vendor = gl.glGetString(gl.GL_VENDOR)
+        version = gl.glGetString(gl.GL_VERSION)
+        renderer = renderer.decode() if renderer else "Unknown"
+        vendor = vendor.decode() if vendor else "Unknown"
+        version = version.decode() if version else "Unknown"
+
+        GLContext._initialized = True
+        logger.info(f"GLSL context initialized in {elapsed:.1f}ms ({self._backend}) - {renderer} ({vendor}), GL {version}")
+
+    def make_current(self):
+        if self._backend == "glfw":
+            glfw.make_context_current(self._window)
+        elif self._backend == "egl":
+            from OpenGL.EGL import eglMakeCurrent
+            eglMakeCurrent(self._egl_display, self._egl_surface, self._egl_surface, self._egl_context)
+        elif self._backend == "osmesa":
+            from OpenGL.osmesa import OSMesaMakeCurrent
+            OSMesaMakeCurrent(self._osmesa_ctx, self._osmesa_buffer, gl.GL_UNSIGNED_BYTE, 64, 64)
+
+        if self._vao is not None:
+            gl.glBindVertexArray(self._vao)
+
+
+def _compile_shader(source: str, shader_type: int) -> int:
+    """Compile a shader and return its ID."""
+    shader = gl.glCreateShader(shader_type)
+    gl.glShaderSource(shader, source)
+    gl.glCompileShader(shader)
+
+    if gl.glGetShaderiv(shader, gl.GL_COMPILE_STATUS) != gl.GL_TRUE:
+        error = gl.glGetShaderInfoLog(shader).decode()
+        gl.glDeleteShader(shader)
+        raise RuntimeError(f"Shader compilation failed:\n{error}")
+
+    return shader
+
+
+def _create_program(vertex_source: str, fragment_source: str) -> int:
+    """Create and link a shader program."""
+    vertex_shader = _compile_shader(vertex_source, gl.GL_VERTEX_SHADER)
+    try:
+        fragment_shader = _compile_shader(fragment_source, gl.GL_FRAGMENT_SHADER)
+    except RuntimeError:
+        gl.glDeleteShader(vertex_shader)
+        raise
+
+    program = gl.glCreateProgram()
+    gl.glAttachShader(program, vertex_shader)
+    gl.glAttachShader(program, fragment_shader)
+    gl.glLinkProgram(program)
+
+    gl.glDeleteShader(vertex_shader)
+    gl.glDeleteShader(fragment_shader)
+
+    if gl.glGetProgramiv(program, gl.GL_LINK_STATUS) != gl.GL_TRUE:
+        error = gl.glGetProgramInfoLog(program).decode()
+
gl.glDeleteProgram(program) + raise RuntimeError(f"Program linking failed:\n{error}") + + return program + + +def _render_shader_batch( + fragment_code: str, + width: int, + height: int, + image_batches: list[list[np.ndarray]], + floats: list[float], + ints: list[int], +) -> list[list[np.ndarray]]: + """ + Render a fragment shader for multiple batches efficiently. + + Compiles shader once, reuses framebuffer/textures across batches. + Supports multi-pass rendering via #pragma passes N directive. + + Args: + fragment_code: User's fragment shader code + width: Output width + height: Output height + image_batches: List of batches, each batch is a list of input images (H, W, C) float32 [0,1] + floats: List of float uniforms + ints: List of int uniforms + + Returns: + List of batch outputs, each is a list of output images (H, W, 4) float32 [0,1] + """ + import time + start_time = time.perf_counter() + + if not image_batches: + return [] + + ctx = GLContext() + ctx.make_current() + + # Convert from GLSL ES to desktop GLSL 330 + fragment_source = _convert_es_to_desktop(fragment_code) + + # Detect how many outputs the shader actually uses + num_outputs = _detect_output_count(fragment_code) + + # Detect multi-pass rendering + num_passes = _detect_pass_count(fragment_code) + + # Track resources for cleanup + program = None + fbo = None + output_textures = [] + input_textures = [] + ping_pong_textures = [] + ping_pong_fbos = [] + + num_inputs = len(image_batches[0]) + + try: + # Compile shaders (once for all batches) + try: + program = _create_program(VERTEX_SHADER, fragment_source) + except RuntimeError: + logger.error(f"Fragment shader:\n{fragment_source}") + raise + + gl.glUseProgram(program) + + # Create framebuffer with only the needed color attachments + fbo = gl.glGenFramebuffers(1) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + + draw_buffers = [] + for i in range(num_outputs): + tex = gl.glGenTextures(1) + output_textures.append(tex) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0 + i, gl.GL_TEXTURE_2D, tex, 0) + draw_buffers.append(gl.GL_COLOR_ATTACHMENT0 + i) + + gl.glDrawBuffers(num_outputs, draw_buffers) + + if gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError("Framebuffer is not complete") + + # Create ping-pong resources for multi-pass rendering + if num_passes > 1: + for _ in range(2): + pp_tex = gl.glGenTextures(1) + ping_pong_textures.append(pp_tex) + gl.glBindTexture(gl.GL_TEXTURE_2D, pp_tex) + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, width, height, 0, gl.GL_RGBA, gl.GL_FLOAT, None) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + pp_fbo = gl.glGenFramebuffers(1) + ping_pong_fbos.append(pp_fbo) + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, pp_fbo) + gl.glFramebufferTexture2D(gl.GL_FRAMEBUFFER, gl.GL_COLOR_ATTACHMENT0, gl.GL_TEXTURE_2D, pp_tex, 0) + gl.glDrawBuffers(1, [gl.GL_COLOR_ATTACHMENT0]) + + if 
gl.glCheckFramebufferStatus(gl.GL_FRAMEBUFFER) != gl.GL_FRAMEBUFFER_COMPLETE: + raise RuntimeError("Ping-pong framebuffer is not complete") + + # Create input textures (reused for all batches) + for i in range(num_inputs): + tex = gl.glGenTextures(1) + input_textures.append(tex) + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MIN_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_MAG_FILTER, gl.GL_LINEAR) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_S, gl.GL_CLAMP_TO_EDGE) + gl.glTexParameteri(gl.GL_TEXTURE_2D, gl.GL_TEXTURE_WRAP_T, gl.GL_CLAMP_TO_EDGE) + + loc = gl.glGetUniformLocation(program, f"u_image{i}") + if loc >= 0: + gl.glUniform1i(loc, i) + + # Set static uniforms (once for all batches) + loc = gl.glGetUniformLocation(program, "u_resolution") + if loc >= 0: + gl.glUniform2f(loc, float(width), float(height)) + + for i, v in enumerate(floats): + loc = gl.glGetUniformLocation(program, f"u_float{i}") + if loc >= 0: + gl.glUniform1f(loc, v) + + for i, v in enumerate(ints): + loc = gl.glGetUniformLocation(program, f"u_int{i}") + if loc >= 0: + gl.glUniform1i(loc, v) + + # Get u_pass uniform location for multi-pass + pass_loc = gl.glGetUniformLocation(program, "u_pass") + + gl.glViewport(0, 0, width, height) + gl.glDisable(gl.GL_BLEND) # Ensure no alpha blending - write output directly + + # Process each batch + all_batch_outputs = [] + for images in image_batches: + # Update input textures with this batch's images + for i, img in enumerate(images): + gl.glActiveTexture(gl.GL_TEXTURE0 + i) + gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[i]) + + # Flip vertically for GL coordinates, ensure RGBA + h, w, c = img.shape + if c == 3: + img_upload = np.empty((h, w, 4), dtype=np.float32) + img_upload[:, :, :3] = img[::-1, :, :] + img_upload[:, :, 3] = 1.0 + else: + img_upload = np.ascontiguousarray(img[::-1, :, :]) + + gl.glTexImage2D(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA32F, w, h, 0, gl.GL_RGBA, gl.GL_FLOAT, img_upload) + + if num_passes == 1: + # Single pass - render directly to output FBO + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + if pass_loc >= 0: + gl.glUniform1i(pass_loc, 0) + gl.glClearColor(0, 0, 0, 0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) + else: + # Multi-pass rendering with ping-pong + for p in range(num_passes): + is_last_pass = (p == num_passes - 1) + + # Set pass uniform + if pass_loc >= 0: + gl.glUniform1i(pass_loc, p) + + if is_last_pass: + # Last pass renders to the main output FBO + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, fbo) + else: + # Intermediate passes render to ping-pong FBO + target_fbo = ping_pong_fbos[p % 2] + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, target_fbo) + + # Set input texture for this pass + gl.glActiveTexture(gl.GL_TEXTURE0) + if p == 0: + # First pass reads from original input + gl.glBindTexture(gl.GL_TEXTURE_2D, input_textures[0]) + else: + # Subsequent passes read from previous pass output + source_tex = ping_pong_textures[(p - 1) % 2] + gl.glBindTexture(gl.GL_TEXTURE_2D, source_tex) + + gl.glClearColor(0, 0, 0, 0) + gl.glClear(gl.GL_COLOR_BUFFER_BIT) + gl.glDrawArrays(gl.GL_TRIANGLES, 0, 3) + + # Read back outputs for this batch + # (glGetTexImage is synchronous, implicitly waits for rendering) + batch_outputs = [] + for tex in output_textures: + gl.glBindTexture(gl.GL_TEXTURE_2D, tex) + data = gl.glGetTexImage(gl.GL_TEXTURE_2D, 0, gl.GL_RGBA, gl.GL_FLOAT) + img = np.frombuffer(data, 
dtype=np.float32).reshape(height, width, 4) + batch_outputs.append(img[::-1, :, :].copy()) + + # Pad with black images for unused outputs + black_img = np.zeros((height, width, 4), dtype=np.float32) + for _ in range(num_outputs, MAX_OUTPUTS): + batch_outputs.append(black_img) + + all_batch_outputs.append(batch_outputs) + + elapsed = (time.perf_counter() - start_time) * 1000 + num_batches = len(image_batches) + pass_info = f", {num_passes} passes" if num_passes > 1 else "" + logger.info(f"GLSL shader executed in {elapsed:.1f}ms ({num_batches} batch{'es' if num_batches != 1 else ''}, {width}x{height}{pass_info})") + + return all_batch_outputs + + finally: + # Unbind before deleting + gl.glBindFramebuffer(gl.GL_FRAMEBUFFER, 0) + gl.glUseProgram(0) + + for tex in input_textures: + gl.glDeleteTextures(int(tex)) + for tex in output_textures: + gl.glDeleteTextures(int(tex)) + for tex in ping_pong_textures: + gl.glDeleteTextures(int(tex)) + if fbo is not None: + gl.glDeleteFramebuffers(1, [fbo]) + for pp_fbo in ping_pong_fbos: + gl.glDeleteFramebuffers(1, [pp_fbo]) + if program is not None: + gl.glDeleteProgram(program) + +class GLSLShader(io.ComfyNode): + + @classmethod + def define_schema(cls) -> io.Schema: + image_template = io.Autogrow.TemplatePrefix( + io.Image.Input("image"), + prefix="image", + min=1, + max=MAX_IMAGES, + ) + + float_template = io.Autogrow.TemplatePrefix( + io.Float.Input("float", default=0.0), + prefix="u_float", + min=0, + max=MAX_UNIFORMS, + ) + + int_template = io.Autogrow.TemplatePrefix( + io.Int.Input("int", default=0), + prefix="u_int", + min=0, + max=MAX_UNIFORMS, + ) + + return io.Schema( + node_id="GLSLShader", + display_name="GLSL Shader", + category="image/shader", + description=( + "Apply GLSL ES fragment shaders to images. " + "u_resolution (vec2) is always available." 
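+                " Shaders may request multi-pass rendering with a '#pragma passes N' directive; "
+                "the zero-based pass index is exposed as the integer uniform u_pass."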
+ ), + inputs=[ + io.String.Input( + "fragment_shader", + default=DEFAULT_FRAGMENT_SHADER, + multiline=True, + tooltip="GLSL fragment shader source code (GLSL ES 3.00 / WebGL 2.0 compatible)", + ), + io.DynamicCombo.Input( + "size_mode", + options=[ + io.DynamicCombo.Option("from_input", []), + io.DynamicCombo.Option( + "custom", + [ + io.Int.Input( + "width", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + io.Int.Input( + "height", + default=512, + min=1, + max=nodes.MAX_RESOLUTION, + ), + ], + ), + ], + tooltip="Output size: 'from_input' uses first input image dimensions, 'custom' allows manual size", + ), + io.Autogrow.Input("images", template=image_template, tooltip=f"Images are available as u_image0-{MAX_IMAGES-1} (sampler2D) in the shader code"), + io.Autogrow.Input("floats", template=float_template, tooltip=f"Floats are available as u_float0-{MAX_UNIFORMS-1} in the shader code"), + io.Autogrow.Input("ints", template=int_template, tooltip=f"Ints are available as u_int0-{MAX_UNIFORMS-1} in the shader code"), + ], + outputs=[ + io.Image.Output(display_name="IMAGE0", tooltip="Available via layout(location = 0) out vec4 fragColor0 in the shader code"), + io.Image.Output(display_name="IMAGE1", tooltip="Available via layout(location = 1) out vec4 fragColor1 in the shader code"), + io.Image.Output(display_name="IMAGE2", tooltip="Available via layout(location = 2) out vec4 fragColor2 in the shader code"), + io.Image.Output(display_name="IMAGE3", tooltip="Available via layout(location = 3) out vec4 fragColor3 in the shader code"), + ], + ) + + @classmethod + def execute( + cls, + fragment_shader: str, + size_mode: SizeModeInput, + images: io.Autogrow.Type, + floats: io.Autogrow.Type = None, + ints: io.Autogrow.Type = None, + **kwargs, + ) -> io.NodeOutput: + image_list = [v for v in images.values() if v is not None] + float_list = ( + [v if v is not None else 0.0 for v in floats.values()] if floats else [] + ) + int_list = [v if v is not None else 0 for v in ints.values()] if ints else [] + + if not image_list: + raise ValueError("At least one input image is required") + + # Determine output dimensions + if size_mode["size_mode"] == "custom": + out_width = size_mode["width"] + out_height = size_mode["height"] + else: + out_height, out_width = image_list[0].shape[1:3] + + batch_size = image_list[0].shape[0] + + # Prepare batches + image_batches = [] + for batch_idx in range(batch_size): + batch_images = [img_tensor[batch_idx].cpu().numpy().astype(np.float32) for img_tensor in image_list] + image_batches.append(batch_images) + + all_batch_outputs = _render_shader_batch( + fragment_shader, + out_width, + out_height, + image_batches, + float_list, + int_list, + ) + + # Collect outputs into tensors + all_outputs = [[] for _ in range(MAX_OUTPUTS)] + for batch_outputs in all_batch_outputs: + for i, out_img in enumerate(batch_outputs): + all_outputs[i].append(torch.from_numpy(out_img)) + + output_tensors = [torch.stack(all_outputs[i], dim=0) for i in range(MAX_OUTPUTS)] + return io.NodeOutput( + *output_tensors, + ui=cls._build_ui_output(image_list, output_tensors[0]), + ) + + @classmethod + def _build_ui_output( + cls, image_list: list[torch.Tensor], output_batch: torch.Tensor + ) -> dict[str, list]: + """Build UI output with input and output images for client-side shader execution.""" + input_images_ui = [] + for img in image_list: + input_images_ui.extend(ui.ImageSaveHelper.save_images( + img, + filename_prefix="GLSLShader_input", + folder_type=io.FolderType.temp, + cls=None, + 
compress_level=1, + )) + + output_images_ui = ui.ImageSaveHelper.save_images( + output_batch, + filename_prefix="GLSLShader_output", + folder_type=io.FolderType.temp, + cls=None, + compress_level=1, + ) + + return {"input_images": input_images_ui, "images": output_images_ui} + + +class GLSLExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [GLSLShader] + + +async def comfy_entrypoint() -> GLSLExtension: + return GLSLExtension() diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py index eee683ee1..e345fe51d 100644 --- a/comfy_extras/nodes_hidream.py +++ b/comfy_extras/nodes_hidream.py @@ -38,6 +38,7 @@ class CLIPTextEncodeHiDream(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeHiDream", + search_aliases=["hidream prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py index 1edc06f3d..056369e86 100644 --- a/comfy_extras/nodes_hooks.py +++ b/comfy_extras/nodes_hooks.py @@ -233,8 +233,8 @@ class SetClipHooks: return { "required": { "clip": ("CLIP",), - "apply_to_conds": ("BOOLEAN", {"default": True}), - "schedule_clip": ("BOOLEAN", {"default": False}) + "apply_to_conds": ("BOOLEAN", {"default": True, "advanced": True}), + "schedule_clip": ("BOOLEAN", {"default": False, "advanced": True}) }, "optional": { "hooks": ("HOOKS",) @@ -248,7 +248,7 @@ class SetClipHooks: def apply_hooks(self, clip: CLIP, schedule_clip: bool, apply_to_conds: bool, hooks: comfy.hooks.HookGroup=None): if hooks is not None: - clip = clip.clone() + clip = clip.clone(disable_dynamic=True) if apply_to_conds: clip.apply_hooks_to_conds = hooks clip.patcher.forced_hooks = hooks.clone() @@ -259,6 +259,7 @@ class SetClipHooks: return (clip,) class ConditioningTimestepsRange: + SEARCH_ALIASES = ["prompt scheduling", "timestep segments", "conditioning phases"] NodeId = 'ConditioningTimestepsRange' NodeName = 'Timesteps Range' @classmethod @@ -468,6 +469,7 @@ class SetHookKeyframes: return (hooks,) class CreateHookKeyframe: + SEARCH_ALIASES = ["hook scheduling", "strength animation", "timed hook"] NodeId = 'CreateHookKeyframe' NodeName = 'Create Hook Keyframe' @classmethod @@ -497,6 +499,7 @@ class CreateHookKeyframe: return (prev_hook_kf,) class CreateHookKeyframesInterpolated: + SEARCH_ALIASES = ["ease hook strength", "smooth hook transition", "interpolate keyframes"] NodeId = 'CreateHookKeyframesInterpolated' NodeName = 'Create Hook Keyframes Interp.' 
@classmethod @@ -509,7 +512,7 @@ class CreateHookKeyframesInterpolated: "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), "keyframes_count": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}), - "print_keyframes": ("BOOLEAN", {"default": False}), + "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}), }, "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), @@ -544,6 +547,7 @@ class CreateHookKeyframesInterpolated: return (prev_hook_kf,) class CreateHookKeyframesFromFloats: + SEARCH_ALIASES = ["batch keyframes", "strength list to keyframes"] NodeId = 'CreateHookKeyframesFromFloats' NodeName = 'Create Hook Keyframes From Floats' @classmethod @@ -553,7 +557,7 @@ class CreateHookKeyframesFromFloats: "floats_strength": ("FLOATS", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "print_keyframes": ("BOOLEAN", {"default": False}), + "print_keyframes": ("BOOLEAN", {"default": False, "advanced": True}), }, "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), @@ -618,6 +622,7 @@ class SetModelHooksOnCond: # Combine Hooks #------------------------------------------ class CombineHooks: + SEARCH_ALIASES = ["merge hooks"] NodeId = 'CombineHooks2' NodeName = 'Combine Hooks [2]' @classmethod diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1..4ea93a499 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -54,7 +56,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): @classmethod def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove @@ -71,7 +73,7 @@ class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: # Using scale factor of 16 instead of 8 latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 16}) class HunyuanVideo15ImageToVideo(io.ComfyNode): @@ -136,7 +138,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode): io.Image.Input("start_image", optional=True), io.ClipVisionOutput.Input("clip_vision_output", optional=True), io.Latent.Input("latent"), - io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01, advanced=True), ], outputs=[ @@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = 
folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode): "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode): "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) @@ -278,6 +285,7 @@ class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode): min=1, max=512, tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.", + advanced=True, ), ], outputs=[ @@ -306,7 +314,7 @@ class HunyuanImageToVideo(io.ComfyNode): io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4), io.Int.Input("batch_size", default=1, min=1, max=4096), - io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]), + io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"], advanced=True), io.Image.Input("start_image", optional=True), ], outputs=[ @@ -377,7 +385,7 @@ class HunyuanRefinerLatent(io.ComfyNode): io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), io.Latent.Input("latent"), - io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01), + io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01, advanced=True), ], outputs=[ diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index adca14f62..df0c3e4b1 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -106,8 +106,8 @@ class VAEDecodeHunyuan3D(IO.ComfyNode): inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), - IO.Int.Input("num_chunks", default=8000, min=1000, max=500000), - IO.Int.Input("octree_resolution", default=256, min=16, max=512), + IO.Int.Input("num_chunks", default=8000, min=1000, max=500000, advanced=True), + IO.Int.Input("octree_resolution", default=256, min=16, max=512, advanced=True), ], outputs=[ IO.Voxel.Output(), @@ -456,7 +456,7 @@ class VoxelToMesh(IO.ComfyNode): category="3d", inputs=[ IO.Voxel.Input("voxel"), - IO.Combo.Input("algorithm", options=["surface net", "basic"]), + IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True), IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), ], outputs=[ @@ -618,17 +618,32 @@ class SaveGLB(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="SaveGLB", + 
display_name="Save 3D Model", + search_aliases=["export 3d model", "save mesh"], category="3d", + essentials_category="Basics", is_output_node=True, inputs=[ - IO.Mesh.Input("mesh"), + IO.MultiType.Input( + IO.Mesh.Input("mesh"), + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DOBJ, + IO.File3DFBX, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="Mesh or 3D file to save", + ), IO.String.Input("filename_prefix", default="mesh/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] ) @classmethod - def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: + def execute(cls, mesh: Types.MESH | Types.File3D, filename_prefix: str) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] @@ -640,15 +655,27 @@ class SaveGLB(IO.ComfyNode): for x in cls.hidden.extra_pnginfo: metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) - for i in range(mesh.vertices.shape[0]): - f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + if isinstance(mesh, Types.File3D): + # Handle File3D input - save BytesIO data to output folder + ext = mesh.format or "glb" + f = f"{filename}_{counter:05}_.{ext}" + mesh.save_to(os.path.join(full_output_folder, f)) results.append({ "filename": f, "subfolder": subfolder, "type": "output" }) - counter += 1 + else: + # Handle Mesh input - save vertices and faces as GLB + for i in range(mesh.vertices.shape[0]): + f = f"{filename}_{counter:05}_.glb" + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + results.append({ + "filename": f, + "subfolder": subfolder, + "type": "output" + }) + counter += 1 return IO.NodeOutput(ui={"3d": results}) diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 0ad5e6773..354d96db1 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -30,10 +30,10 @@ class HyperTile(io.ComfyNode): category="model_patches/unet", inputs=[ io.Model.Input("model"), - io.Int.Input("tile_size", default=256, min=1, max=2048), - io.Int.Input("swap_size", default=2, min=1, max=128), - io.Int.Input("max_depth", default=0, min=0, max=10), - io.Boolean.Input("scale_depth", default=False), + io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True), + io.Int.Input("swap_size", default=2, min=1, max=128, advanced=True), + io.Int.Input("max_depth", default=0, min=0, max=10, advanced=True), + io.Boolean.Input("scale_depth", default=False, advanced=True), ], outputs=[ io.Model.Output(), diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py new file mode 100644 index 000000000..3d943be67 --- /dev/null +++ b/comfy_extras/nodes_image_compare.py @@ -0,0 +1,54 @@ +import nodes + +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension + + +class ImageCompare(IO.ComfyNode): + """Compares two images with a slider interface.""" + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageCompare", + display_name="Image Compare", + description="Compares two images side by side with a slider.", + category="image", + essentials_category="Image Tools", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.Image.Input("image_a", optional=True), + IO.Image.Input("image_b", optional=True), + IO.ImageCompare.Input("compare_view"), + ], + outputs=[], + ) + + 
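+    # The UI payload reuses PreviewImage results, i.e. each entry looks roughly like
+    # {"filename": ..., "subfolder": ..., "type": "temp"}.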
@classmethod + def execute(cls, image_a=None, image_b=None, compare_view=None) -> IO.NodeOutput: + result = {"a_images": [], "b_images": []} + + preview_node = nodes.PreviewImage() + + if image_a is not None and len(image_a) > 0: + saved = preview_node.save_images(image_a, "comfy.compare.a") + result["a_images"] = saved["ui"]["images"] + + if image_b is not None and len(image_b) > 0: + saved = preview_node.save_images(image_b, "comfy.compare.b") + result["b_images"] = saved["ui"]["images"] + + return IO.NodeOutput(ui=result) + + +class ImageCompareExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCompare, + ] + + +async def comfy_entrypoint() -> ImageCompareExtension: + return ImageCompareExtension() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 392aea32c..a8223cf8b 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -2,280 +2,290 @@ from __future__ import annotations import nodes import folder_paths -from comfy.cli_args import args -from PIL import Image -from PIL.PngImagePlugin import PngInfo - -import numpy as np import json import os import re -from io import BytesIO -from inspect import cleandoc +import math import torch import comfy.utils -from comfy.comfy_types import FileLocator, IO from server import PromptServer +from comfy_api.latest import ComfyExtension, IO, UI +from typing_extensions import override + +SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. MAX_RESOLUTION = nodes.MAX_RESOLUTION -class ImageCrop: +class ImageCrop(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "crop" + def define_schema(cls): + return IO.Schema( + node_id="ImageCrop", + search_aliases=["trim"], + display_name="Image Crop (Deprecated)", + category="image/transform", + is_deprecated=True, + essentials_category="Image Tools", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/transform" - - def crop(self, image, width, height, x, y): + @classmethod + def execute(cls, image, width, height, x, y) -> IO.NodeOutput: x = min(x, image.shape[2] - 1) y = min(y, image.shape[1] - 1) to_x = width + x to_y = height + y img = image[:,y:to_y, x:to_x, :] - return (img,) + return IO.NodeOutput(img) -class RepeatImageBatch: + crop = execute # TODO: remove + + +class ImageCropV2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "amount": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageCropV2", + search_aliases=["trim"], + display_name="Image Crop", + category="image/transform", + essentials_category="Image Tools", + inputs=[ + 
IO.Image.Input("image"), + IO.BoundingBox.Input("crop_region", component="ImageCrop"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" + @classmethod + def execute(cls, image, crop_region) -> IO.NodeOutput: + x = crop_region.get("x", 0) + y = crop_region.get("y", 0) + width = crop_region.get("width", 512) + height = crop_region.get("height", 512) - def repeat(self, image, amount): + x = min(x, image.shape[2] - 1) + y = min(y, image.shape[1] - 1) + to_x = width + x + to_y = height + y + img = image[:,y:to_y, x:to_x, :] + return IO.NodeOutput(img, ui=UI.PreviewImage(img)) + + +class BoundingBox(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PrimitiveBoundingBox", + display_name="Bounding Box", + category="utils/primitive", + inputs=[ + IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), + IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION), + IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION), + ], + outputs=[IO.BoundingBox.Output()], + ) + + @classmethod + def execute(cls, x, y, width, height) -> IO.NodeOutput: + return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height}) + + +class RepeatImageBatch(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RepeatImageBatch", + search_aliases=["duplicate image", "clone image"], + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("amount", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) + + @classmethod + def execute(cls, image, amount) -> IO.NodeOutput: s = image.repeat((amount, 1,1,1)) - return (s,) + return IO.NodeOutput(s) -class ImageFromBatch: + repeat = execute # TODO: remove + + +class ImageFromBatch(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}), - "length": ("INT", {"default": 1, "min": 1, "max": 4096}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "frombatch" + def define_schema(cls): + return IO.Schema( + node_id="ImageFromBatch", + search_aliases=["select image", "pick from batch", "extract image"], + category="image/batch", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("batch_index", default=0, min=0, max=4095), + IO.Int.Input("length", default=1, min=1, max=4096), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image/batch" - - def frombatch(self, image, batch_index, length): + @classmethod + def execute(cls, image, batch_index, length) -> IO.NodeOutput: s_in = image batch_index = min(s_in.shape[0] - 1, batch_index) length = min(s_in.shape[0] - batch_index, length) s = s_in[batch_index:batch_index + length].clone() - return (s,) + return IO.NodeOutput(s) + + frombatch = execute # TODO: remove -class ImageAddNoise: +class ImageAddNoise(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "image": ("IMAGE",), - "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}), - "strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("IMAGE",) - FUNCTION = "repeat" + def define_schema(cls): + return IO.Schema( + node_id="ImageAddNoise", + search_aliases=["film grain"], + category="image", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + 
control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + IO.Float.Input("strength", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def repeat(self, image, seed, strength): + @classmethod + def execute(cls, image, seed, strength) -> IO.NodeOutput: generator = torch.manual_seed(seed) s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0) - return (s,) + return IO.NodeOutput(s) -class SaveAnimatedWEBP: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" + repeat = execute # TODO: remove - methods = {"default": 4, "fastest": 0, "slowest": 6} - @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "lossless": ("BOOLEAN", {"default": True}), - "quality": ("INT", {"default": 80, "min": 0, "max": 100}), - "method": (list(s.methods.keys()),), - # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () - FUNCTION = "save_images" - - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, filename_prefix, lossless, quality, method, num_frames=0, prompt=None, extra_pnginfo=None): - method = self.methods.get(method) - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results: list[FileLocator] = [] - pil_images = [] - for image in images: - i = 255. 
* image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = pil_images[0].getexif() - if not args.disable_metadata: - if prompt is not None: - metadata[0x0110] = "prompt:{}".format(json.dumps(prompt)) - if extra_pnginfo is not None: - inital_exif = 0x010f - for x in extra_pnginfo: - metadata[inital_exif] = "{}:{}".format(x, json.dumps(extra_pnginfo[x])) - inital_exif -= 1 - - if num_frames == 0: - num_frames = len(pil_images) - - c = len(pil_images) - for i in range(0, c, num_frames): - file = f"{filename}_{counter:05}_.webp" - pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - animated = num_frames != 1 - return { "ui": { "images": results, "animated": (animated,) } } - -class SaveAnimatedPNG: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAnimatedWEBP(IO.ComfyNode): + COMPRESS_METHODS = {"default": 4, "fastest": 0, "slowest": 6} @classmethod - def INPUT_TYPES(s): - return {"required": - {"images": ("IMAGE", ), - "filename_prefix": ("STRING", {"default": "ComfyUI"}), - "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}), - "compress_level": ("INT", {"default": 4, "min": 0, "max": 9}) - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedWEBP", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Boolean.Input("lossless", default=True), + IO.Int.Input("quality", default=80, min=0, max=100), + IO.Combo.Input("method", options=list(cls.COMPRESS_METHODS.keys())), + # "num_frames": ("INT", {"default": 0, "min": 0, "max": 8192}), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) - RETURN_TYPES = () - FUNCTION = "save_images" + @classmethod + def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_webp_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + lossless=lossless, + quality=quality, + method=cls.COMPRESS_METHODS.get(method) + ) + ) - OUTPUT_NODE = True - - CATEGORY = "image/animation" - - def save_images(self, images, fps, compress_level, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) - results = list() - pil_images = [] - for image in images: - i = 255. 
* image.cpu().numpy() - img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) - pil_images.append(img) - - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add(b"comf", "prompt".encode("latin-1", "strict") + b"\0" + json.dumps(prompt).encode("latin-1", "strict"), after_idat=True) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True) - - file = f"{filename}_{counter:05}_.png" - pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:]) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - - return { "ui": { "images": results, "animated": (True,)} } - -class SVG: - """ - Stores SVG representations via a list of BytesIO objects. - """ - def __init__(self, data: list[BytesIO]): - self.data = data - - def combine(self, other: 'SVG') -> 'SVG': - return SVG(self.data + other.data) - - @staticmethod - def combine_all(svgs: list['SVG']) -> 'SVG': - all_svgs_list: list[BytesIO] = [] - for svg_item in svgs: - all_svgs_list.extend(svg_item.data) - return SVG(all_svgs_list) + save_images = execute # TODO: remove -class ImageStitch: +class SaveAnimatedPNG(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAnimatedPNG", + category="image/animation", + inputs=[ + IO.Image.Input("images"), + IO.String.Input("filename_prefix", default="ComfyUI"), + IO.Float.Input("fps", default=6.0, min=0.01, max=1000.0, step=0.01), + IO.Int.Input("compress_level", default=4, min=0, max=9, advanced=True), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.ImageSaveHelper.get_save_animated_png_ui( + images=images, + filename_prefix=filename_prefix, + cls=cls, + fps=fps, + compress_level=compress_level, + ) + ) + + save_images = execute # TODO: remove + + +class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageStitch", + search_aliases=["combine images", "join images", "concatenate images", "side by side"], + display_name="Image Stitch", + description="Stitches image2 to image1 in the specified direction.\n" + "If image2 is not provided, returns image1 unchanged.\n" + "Optional spacing can be added between images.", + category="image/transform", + inputs=[ + IO.Image.Input("image1"), + IO.Combo.Input("direction", options=["right", "down", "left", "up"], default="right"), + IO.Boolean.Input("match_image_size", default=True), + IO.Int.Input("spacing_width", default=0, min=0, max=1024, step=2, advanced=True), + IO.Combo.Input("spacing_color", options=["white", "black", "red", "green", "blue"], default="white", advanced=True), + IO.Image.Input("image2", optional=True), + ], + outputs=[IO.Image.Output()], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image1": ("IMAGE",), - "direction": (["right", "down", "left", "up"], {"default": "right"}), - "match_image_size": ("BOOLEAN", {"default": True}), - "spacing_width": ( - "INT", - {"default": 0, "min": 0, "max": 1024, "step": 2}, - ), - "spacing_color": ( - ["white", "black", 
"red", "green", "blue"], - {"default": "white"}, - ), - }, - "optional": { - "image2": ("IMAGE",), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "stitch" - CATEGORY = "image/transform" - DESCRIPTION = """ -Stitches image2 to image1 in the specified direction. -If image2 is not provided, returns image1 unchanged. -Optional spacing can be added between images. -""" - - def stitch( - self, + def execute( + cls, image1, direction, match_image_size, spacing_width, spacing_color, image2=None, - ): + ) -> IO.NodeOutput: if image2 is None: - return (image1,) + return IO.NodeOutput(image1) # Handle batch size differences if image1.shape[0] != image2.shape[0]: @@ -412,36 +422,30 @@ Optional spacing can be added between images. images.insert(1, spacing) concat_dim = 2 if direction in ["left", "right"] else 1 - return (torch.cat(images, dim=concat_dim),) + return IO.NodeOutput(torch.cat(images, dim=concat_dim)) -class ResizeAndPadImage: + stitch = execute # TODO: remove + + +class ResizeAndPadImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "image": ("IMAGE",), - "target_width": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "target_height": ("INT", { - "default": 512, - "min": 1, - "max": MAX_RESOLUTION, - "step": 1 - }), - "padding_color": (["white", "black"],), - "interpolation": (["area", "bicubic", "nearest-exact", "bilinear", "lanczos"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ResizeAndPadImage", + search_aliases=["fit to size"], + category="image/transform", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("target_width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("target_height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("padding_color", options=["white", "black"], advanced=True), + IO.Combo.Input("interpolation", options=["area", "bicubic", "nearest-exact", "bilinear", "lanczos"], advanced=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "resize_and_pad" - CATEGORY = "image/transform" - - def resize_and_pad(self, image, target_width, target_height, padding_color, interpolation): + @classmethod + def execute(cls, image, target_width, target_height, padding_color, interpolation) -> IO.NodeOutput: batch_size, orig_height, orig_width, channels = image.shape scale_w = target_width / orig_width @@ -469,52 +473,47 @@ class ResizeAndPadImage: padded[:, :, y_offset:y_offset + new_height, x_offset:x_offset + new_width] = resized output = padded.permute(0, 2, 3, 1) - return (output,) + return IO.NodeOutput(output) -class SaveSVGNode: - """ - Save SVG files on disk. - """ + resize_and_pad = execute # TODO: remove - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" - RETURN_TYPES = () - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "save_svg" - CATEGORY = "image/save" # Changed - OUTPUT_NODE = True +class SaveSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveSVGNode", + search_aliases=["export vector", "save vector graphics"], + description="Save SVG files on disk.", + category="image/save", + inputs=[ + IO.SVG.Input("svg"), + IO.String.Input( + "filename_prefix", + default="svg/ComfyUI", + tooltip="The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "svg": ("SVG",), # Changed - "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) - }, - "hidden": { - "prompt": "PROMPT", - "extra_pnginfo": "EXTRA_PNGINFO" - } - } - - def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None): - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results = list() + def execute(cls, svg: IO.SVG.Type, filename_prefix="svg/ComfyUI") -> IO.NodeOutput: + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) + results: list[UI.SavedResult] = [] # Prepare metadata JSON metadata_dict = {} - if prompt is not None: - metadata_dict["prompt"] = prompt - if extra_pnginfo is not None: - metadata_dict.update(extra_pnginfo) + if cls.hidden.prompt is not None: + metadata_dict["prompt"] = cls.hidden.prompt + if cls.hidden.extra_pnginfo is not None: + metadata_dict.update(cls.hidden.extra_pnginfo) # Convert metadata to JSON string metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None + for batch_number, svg_bytes in enumerate(svg.data): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) file = f"{filename_with_batch_num}_{counter:05}_.svg" @@ -544,57 +543,66 @@ class SaveSVGNode: with open(os.path.join(full_output_folder, file), 'wb') as svg_file: svg_file.write(svg_content.encode('utf-8')) - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) + results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return { "ui": { "images": results } } + return IO.NodeOutput(ui={"images": results}) -class GetImageSize: + save_svg = execute # TODO: remove + + +class GetImageSize(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GetImageSize", + search_aliases=["dimensions", "resolution", "image info"], + display_name="Get Image Size", + description="Returns width and height of the image, and passes it through unchanged.", + category="image", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + IO.Int.Output(display_name="batch_size"), + ], + hidden=[IO.Hidden.unique_id], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE,), - }, - "hidden": { - "unique_id": "UNIQUE_ID", - } - } - - RETURN_TYPES = (IO.INT, IO.INT, IO.INT) - RETURN_NAMES = ("width", "height", "batch_size") - FUNCTION = "get_size" - - CATEGORY = "image" - DESCRIPTION = """Returns width and height of the image, and passes it through unchanged.""" - - def get_size(self, image, unique_id=None) -> tuple[int, int]: + def execute(cls, image) -> IO.NodeOutput: height = image.shape[1] width = image.shape[2] batch_size = image.shape[0] # Send progress text to display size on the node - if unique_id: - 
PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", unique_id)
+        if cls.hidden.unique_id:
+            PromptServer.instance.send_progress_text(f"width: {width}, height: {height}\n batch size: {batch_size}", cls.hidden.unique_id)
 
-        return width, height, batch_size
+        return IO.NodeOutput(width, height, batch_size)
 
 
-class ImageRotate:
+    get_size = execute  # TODO: remove
+
+
+class ImageRotate(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": (IO.IMAGE,),
-                              "rotation": (["none", "90 degrees", "180 degrees", "270 degrees"],),
-                              }}
-    RETURN_TYPES = (IO.IMAGE,)
-    FUNCTION = "rotate"
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageRotate",
+            display_name="Image Rotate",
+            search_aliases=["turn", "flip orientation"],
+            category="image/transform",
+            essentials_category="Image Tools",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input("rotation", options=["none", "90 degrees", "180 degrees", "270 degrees"]),
+            ],
+            outputs=[IO.Image.Output()],
+        )
 
-    CATEGORY = "image/transform"
-
-    def rotate(self, image, rotation):
+    @classmethod
+    def execute(cls, image, rotation) -> IO.NodeOutput:
         rotate_by = 0
         if rotation.startswith("90"):
             rotate_by = 1
@@ -604,41 +612,57 @@ class ImageRotate:
             rotate_by = 3
 
         image = torch.rot90(image, k=rotate_by, dims=[2, 1])
-        return (image,)
+        return IO.NodeOutput(image)
 
 
-class ImageFlip:
+    rotate = execute  # TODO: remove
+
+
+class ImageFlip(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": { "image": (IO.IMAGE,),
-                              "flip_method": (["x-axis: vertically", "y-axis: horizontally"],),
-                              }}
-    RETURN_TYPES = (IO.IMAGE,)
-    FUNCTION = "flip"
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageFlip",
+            search_aliases=["mirror", "reflect"],
+            category="image/transform",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input("flip_method", options=["x-axis: vertically", "y-axis: horizontally"]),
+            ],
+            outputs=[IO.Image.Output()],
+        )
 
-    CATEGORY = "image/transform"
-
-    def flip(self, image, flip_method):
+    @classmethod
+    def execute(cls, image, flip_method) -> IO.NodeOutput:
         if flip_method.startswith("x"):
             image = torch.flip(image, dims=[1])
         elif flip_method.startswith("y"):
             image = torch.flip(image, dims=[2])
 
-        return (image,)
+        return IO.NodeOutput(image)
 
 
-class ImageScaleToMaxDimension:
-    upscale_methods = ["area", "lanczos", "bilinear", "nearest-exact", "bilinear", "bicubic"]
+    flip = execute  # TODO: remove
+
+
+class ImageScaleToMaxDimension(IO.ComfyNode):
     @classmethod
-    def INPUT_TYPES(s):
-        return {"required": {"image": ("IMAGE",),
-                             "upscale_method": (s.upscale_methods,),
-                             "largest_size": ("INT", {"default": 512, "min": 0, "max": MAX_RESOLUTION, "step": 1})}}
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "upscale"
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageScaleToMaxDimension",
+            category="image/upscaling",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Combo.Input(
+                    "upscale_method",
+                    options=["area", "lanczos", "bilinear", "nearest-exact", "bicubic"],
+                ),
+                IO.Int.Input("largest_size", default=512, min=0, max=MAX_RESOLUTION, step=1),
+            ],
+            outputs=[IO.Image.Output()],
+        )
 
-    CATEGORY = "image/upscaling"
-
-    def upscale(self, image, upscale_method, largest_size):
+    @classmethod
+    def execute(cls, image, upscale_method, largest_size) -> IO.NodeOutput:
         height = image.shape[1]
         width = image.shape[2]
 
@@ -655,20 +679,172 @@ class ImageScaleToMaxDimension:
         samples = image.movedim(-1, 1)
         s = comfy.utils.common_upscale(samples, width, height, upscale_method, "disabled")
         s = s.movedim(1, -1)
-        return (s,)
+        return IO.NodeOutput(s)
 
 
-NODE_CLASS_MAPPINGS = {
-    "ImageCrop": ImageCrop,
-    "RepeatImageBatch": RepeatImageBatch,
-    "ImageFromBatch": ImageFromBatch,
-    "ImageAddNoise": ImageAddNoise,
-    "SaveAnimatedWEBP": SaveAnimatedWEBP,
-    "SaveAnimatedPNG": SaveAnimatedPNG,
-    "SaveSVGNode": SaveSVGNode,
-    "ImageStitch": ImageStitch,
-    "ResizeAndPadImage": ResizeAndPadImage,
-    "GetImageSize": GetImageSize,
-    "ImageRotate": ImageRotate,
-    "ImageFlip": ImageFlip,
-    "ImageScaleToMaxDimension": ImageScaleToMaxDimension,
-}
+    upscale = execute  # TODO: remove
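+
+
+# Illustrative note on the tiling grid computed by get_grid_coords below (example
+# numbers assumed, not part of the node): tile starts are clamped so the last tile
+# ends exactly on the image border. For width=2048, tile_width=1024, overlap=128,
+# stride_x = round(max(256, 896)) = 896, giving x windows [0, 1024), [896, 1920)
+# and [1024, 2048); edge tiles can therefore overlap neighbours by more than `overlap`.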
+
+
+class SplitImageToTileList(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="SplitImageToTileList",
+            category="image/batch",
+            search_aliases=["split image", "tile image", "slice image"],
+            display_name="Split Image into List of Tiles",
+            description="Splits an image into a batched list of tiles with a specified overlap.",
+            inputs=[
+                IO.Image.Input("image"),
+                IO.Int.Input("tile_width", default=1024, min=64, max=MAX_RESOLUTION),
+                IO.Int.Input("tile_height", default=1024, min=64, max=MAX_RESOLUTION),
+                IO.Int.Input("overlap", default=128, min=0, max=4096),
+            ],
+            outputs=[
+                IO.Image.Output(is_output_list=True),
+            ],
+        )
+
+    @staticmethod
+    def get_grid_coords(width, height, tile_width, tile_height, overlap):
+        coords = []
+        stride_x = round(max(tile_width * 0.25, tile_width - overlap))
+        stride_y = round(max(tile_height * 0.25, tile_height - overlap))
+
+        y = 0
+        while y < height:
+            x = 0
+            y_end = min(y + tile_height, height)
+            y_start = max(0, y_end - tile_height)
+
+            while x < width:
+                x_end = min(x + tile_width, width)
+                x_start = max(0, x_end - tile_width)
+
+                coords.append((x_start, y_start, x_end, y_end))
+
+                if x_end >= width:
+                    break
+                x += stride_x
+
+            if y_end >= height:
+                break
+            y += stride_y
+
+        return coords
+
+    @classmethod
+    def execute(cls, image, tile_width, tile_height, overlap):
+        b, h, w, c = image.shape
+        coords = cls.get_grid_coords(w, h, tile_width, tile_height, overlap)
+
+        output_list = []
+        for (x_start, y_start, x_end, y_end) in coords:
+            tile = image[:, y_start:y_end, x_start:x_end, :]
+            output_list.append(tile)
+
+        return IO.NodeOutput(output_list)
+
+
+class ImageMergeTileList(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="ImageMergeTileList",
+            display_name="Merge List of Tiles to Image",
+            category="image/batch",
+            search_aliases=["split image", "tile image", "slice image"],
+            is_input_list=True,
+            inputs=[
+                IO.Image.Input("image_list"),
+                IO.Int.Input("final_width", default=1024, min=64, max=32768),
+                IO.Int.Input("final_height", default=1024, min=64, max=32768),
+                IO.Int.Input("overlap", default=128, min=0, max=4096),
+            ],
+            outputs=[
+                IO.Image.Output(is_output_list=False),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, image_list, final_width, final_height, overlap):
+        w = final_width[0]
+        h = final_height[0]
+        ovlp = overlap[0]
+        feather_str = 1.0
+
+        first_tile = image_list[0]
+        b, t_h, t_w, c = first_tile.shape
+        device = first_tile.device
+        dtype = first_tile.dtype
+
+        coords = SplitImageToTileList.get_grid_coords(w, h, t_w, t_h, ovlp)
+
+        canvas = torch.zeros((b, h, w, c), device=device, dtype=dtype)
+        weights = torch.zeros((b, h, w, 1), device=device, dtype=dtype)
+
+        if ovlp > 0:
+            y_w = torch.sin(math.pi * torch.linspace(0, 1, t_h, device=device, dtype=dtype))
+            x_w = torch.sin(math.pi * torch.linspace(0, 1, t_w, device=device, dtype=dtype))
+            y_w =
torch.clamp(y_w, min=1e-5) + x_w = torch.clamp(x_w, min=1e-5) + + sine_mask = (y_w.unsqueeze(1) * x_w.unsqueeze(0)).unsqueeze(0).unsqueeze(-1) + flat_mask = torch.ones_like(sine_mask) + + weight_mask = torch.lerp(flat_mask, sine_mask, feather_str) + else: + weight_mask = torch.ones((1, t_h, t_w, 1), device=device, dtype=dtype) + + for i, (x_start, y_start, x_end, y_end) in enumerate(coords): + if i >= len(image_list): + break + + tile = image_list[i] + + region_h = y_end - y_start + region_w = x_end - x_start + + real_h = min(region_h, tile.shape[1]) + real_w = min(region_w, tile.shape[2]) + + y_end_actual = y_start + real_h + x_end_actual = x_start + real_w + + tile_crop = tile[:, :real_h, :real_w, :] + mask_crop = weight_mask[:, :real_h, :real_w, :] + + canvas[:, y_start:y_end_actual, x_start:x_end_actual, :] += tile_crop * mask_crop + weights[:, y_start:y_end_actual, x_start:x_end_actual, :] += mask_crop + + weights[weights == 0] = 1.0 + merged_image = canvas / weights + + return IO.NodeOutput(merged_image) + + +class ImagesExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + ImageCrop, + ImageCropV2, + BoundingBox, + RepeatImageBatch, + ImageFromBatch, + ImageAddNoise, + SaveAnimatedWEBP, + SaveAnimatedPNG, + SaveSVGNode, + ImageStitch, + ResizeAndPadImage, + GetImageSize, + ImageRotate, + ImageFlip, + ImageScaleToMaxDimension, + SplitImageToTileList, + ImageMergeTileList, + ] + + +async def comfy_entrypoint() -> ImagesExtension: + return ImagesExtension() diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 9cb234be1..346c50cde 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -104,6 +104,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeKandinsky5", + search_aliases=["kandinsky prompt"], category="advanced/conditioning/kandinsky5", inputs=[ io.Clip.Input("clip"), diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 2815c5ffc..8bb368dec 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -21,6 +21,7 @@ class LatentAdd(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentAdd", + search_aliases=["combine latents", "sum latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -47,6 +48,7 @@ class LatentSubtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentSubtract", + search_aliases=["difference latent", "remove features"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -73,6 +75,7 @@ class LatentMultiply(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentMultiply", + search_aliases=["scale latent", "amplify latent", "latent gain"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -96,6 +99,7 @@ class LatentInterpolate(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentInterpolate", + search_aliases=["blend latent", "mix latent", "lerp latent", "transition"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -134,6 +138,7 @@ class LatentConcat(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentConcat", + search_aliases=["join latents", "stitch latents"], category="latent/advanced", inputs=[ io.Latent.Input("samples1"), @@ -173,6 +178,7 @@ class LatentCut(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCut", + 
search_aliases=["crop latent", "slice latent", "extract region"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -213,6 +219,7 @@ class LatentCutToBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentCutToBatch", + search_aliases=["slice to batch", "split latent", "tile latent"], category="latent/advanced", inputs=[ io.Latent.Input("samples"), @@ -254,7 +261,9 @@ class LatentBatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentBatch", + search_aliases=["combine latents", "merge latents", "join latents"], category="latent/batch", + is_deprecated=True, inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -309,6 +318,7 @@ class LatentApplyOperation(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentApplyOperation", + search_aliases=["transform latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ @@ -364,6 +374,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentOperationTonemapReinhard", + search_aliases=["hdr latent"], category="latent/advanced/operations", is_experimental=True, inputs=[ @@ -380,8 +391,9 @@ class LatentOperationTonemapReinhard(io.ComfyNode): latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None] normalized_latent = latent / latent_vector_magnitude - mean = torch.mean(latent_vector_magnitude, dim=(1,2,3), keepdim=True) - std = torch.std(latent_vector_magnitude, dim=(1,2,3), keepdim=True) + dims = list(range(1, latent_vector_magnitude.ndim)) + mean = torch.mean(latent_vector_magnitude, dim=dims, keepdim=True) + std = torch.std(latent_vector_magnitude, dim=dims, keepdim=True) top = (std * 5 + mean) * multiplier @@ -401,9 +413,9 @@ class LatentOperationSharpen(io.ComfyNode): category="latent/advanced/operations", is_experimental=True, inputs=[ - io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1), - io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1), - io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01), + io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1, advanced=True), + io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01, advanced=True), ], outputs=[ io.LatentOperation.Output(), diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 545588ef8..9112bdd0a 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -1,9 +1,10 @@ import nodes import folder_paths import os +import uuid from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, InputImpl, UI +from comfy_api.latest import IO, UI, ComfyExtension, InputImpl, Types from pathlib import Path @@ -24,12 +25,13 @@ class Load3D(IO.ComfyNode): files = [ normalize_path(str(file_path.relative_to(base_path))) for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} + if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl', '.spz', '.splat', '.ply', '.ksplat'} ] return IO.Schema( node_id="Load3D", display_name="Load 3D & Animation", category="3d", + essentials_category="Basics", is_experimental=True, inputs=[ IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), @@ -44,6 +46,7 @@ class Load3D(IO.ComfyNode): IO.Image.Output(display_name="normal"), 
IO.Load3DCamera.Output(display_name="camera_info"), IO.Video.Output(display_name="recording_video"), + IO.File3DAny.Output(display_name="model_3d"), ], ) @@ -65,7 +68,8 @@ class Load3D(IO.ComfyNode): video = InputImpl.VideoFromFile(recording_video_path) - return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d) process = execute # TODO: remove @@ -75,23 +79,41 @@ class Preview3D(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Preview3D", + search_aliases=["view mesh", "3d viewer"], display_name="Preview 3D & Animation", category="3d", is_experimental=True, is_output_node=True, inputs=[ - IO.String.Input("model_file", default="", multiline=False), - IO.Load3DCamera.Input("camera_info", optional=True), - IO.Image.Input("bg_image", optional=True), + IO.MultiType.Input( + IO.String.Input("model_file", default="", multiline=False), + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DFBX, + IO.File3DOBJ, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="3D model file or path string", + ), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Image.Input("bg_image", optional=True, advanced=True), ], outputs=[], ) @classmethod - def execute(cls, model_file, **kwargs) -> IO.NodeOutput: + def execute(cls, model_file: str | Types.File3D, **kwargs) -> IO.NodeOutput: + if isinstance(model_file, Types.File3D): + filename = f"preview3d_{uuid.uuid4().hex}.{model_file.format}" + model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename)) + else: + filename = model_file camera_info = kwargs.get("camera_info", None) bg_image = kwargs.get("bg_image", None) - return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) + return IO.NodeOutput(ui=UI.PreviewUI3D(filename, camera_info, bg_image=bg_image)) process = execute # TODO: remove diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 95a6ba788..c066064ac 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -1,8 +1,11 @@ +from __future__ import annotations from typing import TypedDict from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy_api.latest import _io +# sentinel for missing inputs +MISSING = object() class SwitchNode(io.ComfyNode): @@ -14,6 +17,37 @@ class SwitchNode(io.ComfyNode): display_name="Switch", category="logic", is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True), + io.MatchType.Input("on_true", template=template, lazy=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=None, on_true=None): + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def execute(cls, switch, on_true, on_false) -> io.NodeOutput: + return io.NodeOutput(on_true if switch else on_false) + + +class SoftSwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySoftSwitchNode", + display_name="Soft Switch", + category="logic", + is_experimental=True, inputs=[ io.Boolean.Input("switch"), 
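+                # Both branches are lazy and optional: normally only the branch selected
+                # by `switch` is evaluated, and if just one branch is connected that one
+                # is used regardless of the switch value.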
io.MatchType.Input("on_false", template=template, lazy=True, optional=True), @@ -25,14 +59,14 @@ class SwitchNode(io.ComfyNode): ) @classmethod - def check_lazy_status(cls, switch, on_false=..., on_true=...): - # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + def check_lazy_status(cls, switch, on_false=MISSING, on_true=MISSING): + # We use MISSING instead of None, as None is passed for connected-but-unevaluated inputs. # This trick allows us to ignore the value of the switch and still be able to run execute(). # One of the inputs may be missing, in which case we need to evaluate the other input - if on_false is ...: + if on_false is MISSING: return ["on_true"] - if on_true is ...: + if on_true is MISSING: return ["on_false"] # Normal lazy switch operation if switch and on_true is None: @@ -41,22 +75,54 @@ class SwitchNode(io.ComfyNode): return ["on_false"] @classmethod - def validate_inputs(cls, switch, on_false=..., on_true=...): + def validate_inputs(cls, switch, on_false=MISSING, on_true=MISSING): # This check happens before check_lazy_status(), so we can eliminate the case where # both inputs are missing. - if on_false is ... and on_true is ...: + if on_false is MISSING and on_true is MISSING: return "At least one of on_false or on_true must be connected to Switch node" return True @classmethod - def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: - if on_true is ...: + def execute(cls, switch, on_true=MISSING, on_false=MISSING) -> io.NodeOutput: + if on_true is MISSING: return io.NodeOutput(on_false) - if on_false is ...: + if on_false is MISSING: return io.NodeOutput(on_true) return io.NodeOutput(on_true if switch else on_false) +class CustomComboNode(io.ComfyNode): + """ + Frontend node that allows user to write their own options for a combo. + This is here to make sure the node has a backend-representation to avoid some annoyances. + """ + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CustomCombo", + display_name="Custom Combo", + category="utils", + is_experimental=True, + inputs=[io.Combo.Input("choice", options=[])], + outputs=[ + io.String.Output(display_name="STRING"), + io.Int.Output(display_name="INDEX"), + ], + accept_all_inputs=True, + ) + + @classmethod + def validate_inputs(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> bool: + # NOTE: DO NOT DO THIS unless you want to skip validation entirely on the node's inputs. + # I am doing that here because the widgets (besides the combo dropdown) on this node are fully frontend defined. + # I need to skip checking that the chosen combo option is in the options list, since those are defined by the user. 
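+        # Returning True unconditionally accepts whatever option the user typed;
+        # a conventional combo node would instead check `choice` against its options list.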
+ return True + + @classmethod + def execute(cls, choice: io.Combo.Type, index: int = 0, **kwargs) -> io.NodeOutput: + return io.NodeOutput(choice, index) + + class DCTestNode(io.ComfyNode): class DCValues(TypedDict): combo: str @@ -72,14 +138,14 @@ class DCTestNode(io.ComfyNode): display_name="DCTest", category="logic", is_output_node=True, - inputs=[_io.DynamicCombo.Input("combo", options=[ - _io.DynamicCombo.Option("option1", [io.String.Input("string")]), - _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), - _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), - _io.DynamicCombo.Option("option4", [ - _io.DynamicCombo.Input("subcombo", options=[ - _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), - _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + inputs=[io.DynamicCombo.Input("combo", options=[ + io.DynamicCombo.Option("option1", [io.String.Input("string")]), + io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + io.DynamicCombo.Option("option4", [ + io.DynamicCombo.Input("subcombo", options=[ + io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), ]) ])] )], @@ -141,14 +207,67 @@ class AutogrowPrefixTestNode(io.ComfyNode): combined = ",".join([str(x) for x in vals]) return io.NodeOutput(combined) +class ComboOutputTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComboOptionTestNode", + display_name="ComboOptionTest", + category="logic", + inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), + io.Combo.Input("combo2", options=["option4", "option5", "option6"])], + outputs=[io.Combo.Output(), io.Combo.Output()], + ) + + @classmethod + def execute(cls, combo: io.Combo.Type, combo2: io.Combo.Type) -> io.NodeOutput: + return io.NodeOutput(combo, combo2) + +class ConvertStringToComboNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ConvertStringToComboNode", + search_aliases=["string to dropdown", "text to combo"], + display_name="Convert String to Combo", + category="logic", + inputs=[io.String.Input("string")], + outputs=[io.Combo.Output()], + ) + + @classmethod + def execute(cls, string: str) -> io.NodeOutput: + return io.NodeOutput(string) + +class InvertBooleanNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="InvertBooleanNode", + search_aliases=["not", "toggle", "negate", "flip boolean"], + display_name="Invert Boolean", + category="logic", + inputs=[io.Boolean.Input("boolean")], + outputs=[io.Boolean.Output()], + ) + + @classmethod + def execute(cls, boolean: bool) -> io.NodeOutput: + return io.NodeOutput(not boolean) + class LogicExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ - # SwitchNode, + SwitchNode, + CustomComboNode, + # SoftSwitchNode, + # ConvertStringToComboNode, # DCTestNode, # AutogrowNamesTestNode, # AutogrowPrefixTestNode, + # ComboOutputTestNode, + # InvertBooleanNode, ] async def comfy_entrypoint() -> LogicExtension: diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py new file mode 100644 index 000000000..937a0fbfb --- /dev/null +++ b/comfy_extras/nodes_lora_debug.py @@ -0,0 +1,79 @@ +import folder_paths +import comfy.utils +import comfy.sd + + +class 
LoraLoaderBypass: + """ + Apply LoRA in bypass mode without modifying base model weights. + + Bypass mode computes: output = base_forward(x) + lora_path(x) + This is useful for training and when model weights are offloaded. + """ + + def __init__(self): + self.loaded_lora = None + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "clip": ("CLIP", {"tooltip": "The CLIP model the LoRA will be applied to."}), + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + "strength_clip": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the CLIP model. This value can be negative."}), + } + } + + RETURN_TYPES = ("MODEL", "CLIP") + OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") + FUNCTION = "load_lora" + + CATEGORY = "loaders" + DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." + EXPERIMENTAL = True + + def load_lora(self, model, clip, lora_name, strength_model, strength_clip): + if strength_model == 0 and strength_clip == 0: + return (model, clip) + + lora_path = folder_paths.get_full_path_or_raise("loras", lora_name) + lora = None + if self.loaded_lora is not None: + if self.loaded_lora[0] == lora_path: + lora = self.loaded_lora[1] + else: + self.loaded_lora = None + + if lora is None: + lora = comfy.utils.load_torch_file(lora_path, safe_load=True) + self.loaded_lora = (lora_path, lora) + + model_lora, clip_lora = comfy.sd.load_bypass_lora_for_models(model, clip, lora, strength_model, strength_clip) + return (model_lora, clip_lora) + + +class LoraLoaderBypassModelOnly(LoraLoaderBypass): + @classmethod + def INPUT_TYPES(s): + return {"required": { "model": ("MODEL",), + "lora_name": (folder_paths.get_filename_list("loras"), ), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01}), + }} + RETURN_TYPES = ("MODEL",) + FUNCTION = "load_lora_model_only" + + def load_lora_model_only(self, model, lora_name, strength_model): + return (self.load_lora(model, None, lora_name, strength_model, 0)[0],) + + +NODE_CLASS_MAPPINGS = { + "LoraLoaderBypass": LoraLoaderBypass, + "LoraLoaderBypassModelOnly": LoraLoaderBypassModelOnly, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LoraLoaderBypass": "Load LoRA (Bypass) (For debugging)", + "LoraLoaderBypassModelOnly": "Load LoRA (Bypass, Model Only) (for debugging)", +} diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py index a2375cba7..975f90f45 100644 --- a/comfy_extras/nodes_lora_extract.py +++ b/comfy_extras/nodes_lora_extract.py @@ -7,6 +7,7 @@ import logging from enum import Enum from typing_extensions import override from comfy_api.latest import ComfyExtension, io +from tqdm.auto import trange CLAMP_QUANTILE = 0.99 @@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD, "full_diff": LORAType.FULL_DIFF} def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False): - comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) + comfy.model_management.load_models_gpu([model_diff]) sd = 
model_diff.model_state_dict(filter_prefix=prefix_model)
-    for k in sd:
-        if k.endswith(".weight"):
+    sd_keys = list(sd.keys())
+    for index in trange(len(sd_keys), unit="weight"):
+        k = sd_keys[index]
+        op_keys = sd_keys[index].rsplit('.', 1)
+        if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
+            continue
+        op = comfy.utils.get_attr(model_diff.model, op_keys[0])
+        if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
+            weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
+        else:
             weight_diff = sd[k]
+
+        if op_keys[1] == "weight":
             if lora_type == LORAType.STANDARD:
                 if weight_diff.ndim < 2:
                     if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
             elif lora_type == LORAType.FULL_DIFF:
                 output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
 
-        elif bias_diff and k.endswith(".bias"):
-            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
+        elif bias_diff and op_keys[1] == "bias":
+            output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
     return output_sd
 
 class LoraSave(io.ComfyNode):
@@ -78,13 +89,14 @@
     def define_schema(cls):
         return io.Schema(
             node_id="LoraSave",
+            search_aliases=["export lora"],
             display_name="Extract and Save Lora",
             category="_for_testing",
             inputs=[
                 io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
-                io.Int.Input("rank", default=8, min=1, max=4096, step=1),
-                io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
-                io.Boolean.Input("bias_diff", default=True),
+                io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
+                io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys()), advanced=True),
+                io.Boolean.Input("bias_diff", default=True, advanced=True),
                 io.Model.Input(
                     "model_diff",
                     tooltip="The ModelSubtract output to be converted to a lora.",
diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py
index 50da5f4eb..d7c2e8744 100644
--- a/comfy_extras/nodes_lt.py
+++ b/comfy_extras/nodes_lt.py
@@ -3,6 +3,7 @@ import node_helpers
 import torch
 import comfy.model_management
 import comfy.model_sampling
+import comfy.samplers
 import comfy.utils
 import math
 import numpy as np
@@ -81,6 +82,89 @@ class LTXVImgToVideo(io.ComfyNode):
     generate = execute  # TODO: remove
 
 
+class LTXVImgToVideoInplace(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="LTXVImgToVideoInplace",
+            category="conditioning/video_models",
+            inputs=[
+                io.Vae.Input("vae"),
+                io.Image.Input("image"),
+                io.Latent.Input("latent"),
+                io.Float.Input("strength", default=1.0, min=0.0, max=1.0),
+                io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.")
+            ],
+            outputs=[
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
+        if bypass:
+            return io.NodeOutput(latent)
+
+        samples = latent["samples"]
+        _, height_scale_factor, width_scale_factor = (
+            vae.downscale_index_formula
+        )
+
+        batch, _, latent_frames, latent_height, latent_width = samples.shape
+        width = latent_width * width_scale_factor
+        height = latent_height * height_scale_factor
+
+        if image.shape[1] != height or image.shape[2] != width:
+            pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+        else:
+            pixels = image
+        encode_pixels = pixels[:, :, :, :3]
+        t = vae.encode(encode_pixels)
+
+        samples[:, :, :t.shape[2]] = t
+
+        conditioning_latent_frames_mask = torch.ones(
+            (batch, 1, latent_frames, 1, 1),
+            dtype=torch.float32,
+            device=samples.device,
+        )
+        conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
+
+        return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
+
+    generate = execute  # TODO: remove
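+
+
+# Sketch of the noise-mask convention relied on above (assuming ComfyUI's usual
+# semantics where a noise_mask value of 0.0 keeps a latent frame unchanged and
+# 1.0 lets it be fully re-noised): the mask starts at 1.0 everywhere, and the
+# conditioned frames are set to 1.0 - strength, so strength=1.0 pins the encoded
+# image frames while strength=0.0 leaves them free to change.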
+
+
+def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
+    """Append a guide_attention_entry to both positive and negative conditioning.
+
+    Each entry tracks one guide reference for per-reference attention control.
+    Entries are derived independently from each conditioning to avoid cross-contamination.
+    """
+    new_entry = {
+        "pre_filter_count": pre_filter_count,
+        "strength": strength,
+        "pixel_mask": None,
+        "latent_shape": latent_shape,
+    }
+    results = []
+    for cond in (positive, negative):
+        # Read existing entries from this specific conditioning
+        existing = []
+        for t in cond:
+            found = t[1].get("guide_attention_entries", None)
+            if found is not None:
+                existing = found
+                break
+        # Shallow copy and append (no deepcopy needed — entries contain
+        # only scalars and None for pixel_mask at this call site).
+        entries = [*existing, new_entry]
+        results.append(node_helpers.conditioning_set_values(
+            cond, {"guide_attention_entries": entries}
+        ))
+    return results[0], results[1]
+
+
 def conditioning_get_any_value(conditioning, key, default=None):
     for t in conditioning:
         if key in t[1]:
@@ -106,12 +190,12 @@ def get_keyframe_idxs(cond):
     keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
     if keyframe_idxs is None:
         return None, 0
-    num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0]
+    # keyframe_idxs contains start/end positions (last dimension), checking for unique values only for the start
+    num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
     return keyframe_idxs, num_keyframes
 
 class LTXVAddGuide(io.ComfyNode):
-    NUM_PREFIX_FRAMES = 2
-    PATCHIFIER = SymmetricPatchifier(1)
+    PATCHIFIER = SymmetricPatchifier(1, start_end=True)
 
     @classmethod
     def define_schema(cls):
@@ -170,11 +254,26 @@
         return frame_idx, latent_idx
 
     @classmethod
-    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors):
+    def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None):
         keyframe_idxs, _ = get_keyframe_idxs(cond)
         _, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
-        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0)  # we need the causal fix only if we're placing the new latents at index 0
+        if causal_fix is None:
+            causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1
+        pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix)
         pixel_coords[:, 0] += frame_idx
+
+        # The following adjusts keyframe end positions for small grid IC-LoRA.
+        # After dilation, the small grid has the same size and position as the large grid,
+        # but each token encodes a larger image patch. We adjust the end position (not start)
+        # so that RoPE represents the correct middle point of each token.
+        # keyframe_idxs dims: (batch, spatial_dim [t,h,w], token_id, [start, end])
+        # We only adjust h,w (not t) in dim 1, and only end (not start) in dim 3.
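+        # Worked example with assumed values: if latent_downscale_factor=2 and the
+        # spatial scale factors are (32, 32), each token's h/w end coordinate below
+        # shifts by (2 - 1) * 32 = +32 pixels; start coordinates are untouched.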
+ spatial_end_offset = (latent_downscale_factor - 1) * torch.tensor( + scale_factors[1:], + device=pixel_coords.device, + ).view(1, -1, 1, 1) + pixel_coords[:, 1:, :, 1:] += spatial_end_offset.to(pixel_coords.dtype) + if keyframe_idxs is None: keyframe_idxs = pixel_coords else: @@ -182,26 +281,35 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") - positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) - negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) + positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) + negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly. + if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,31 +346,22 @@ class LTXVAddGuide(io.ComfyNode): frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." 
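+        # The guide latent is appended in full as keyframes by the call below, and each
+        # guide is also registered for per-reference attention control.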
- num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, latent_image, noise_mask, - t[:, :, :num_prefix_frames], + t, strength, scale_factors, ) - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( - latent_image, - noise_mask, - t, - latent_idx, - strength, + # Track this guide for per-reference attention control. + pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4] + guide_latent_shape = list(t.shape[2:]) # [F, H, W] + positive, negative = _append_guide_attention_entry( + positive, negative, pre_filter_count, guide_latent_shape, strength=strength, ) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) @@ -300,8 +399,14 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent_image[:, :, :-num_keyframes] noise_mask = noise_mask[:, :, :-num_keyframes] - positive = node_helpers.conditioning_set_values(positive, {"keyframe_idxs": None}) - negative = node_helpers.conditioning_set_values(negative, {"keyframe_idxs": None}) + positive = node_helpers.conditioning_set_values(positive, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) + negative = node_helpers.conditioning_set_values(negative, { + "keyframe_idxs": None, + "guide_attention_entries": None, + }) return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) @@ -391,6 +496,7 @@ class LTXVScheduler(io.ComfyNode): id="stretch", default=True, tooltip="Stretch the sigmas to be in the range [terminal, 1].", + advanced=True, ), io.Float.Input( id="terminal", @@ -399,6 +505,7 @@ class LTXVScheduler(io.ComfyNode): max=0.99, step=0.01, tooltip="The terminal value of the sigmas after stretching.", + advanced=True, ), io.Latent.Input("latent", optional=True), ], @@ -507,18 +614,169 @@ class LTXVPreprocess(io.ComfyNode): preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + 
io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + +class LTXVReferenceAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVReferenceAudio", + display_name="LTXV Reference Audio (ID-LoRA)", + category="conditioning/audio", + description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."), + io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."), + io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."), + ], + outputs=[ + io.Model.Output(), + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + ], + ) + + @classmethod + def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: + # Encode reference audio to latents and patchify + audio_latents = audio_vae.encode(reference_audio) + b, c, t, f = audio_latents.shape + ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) + ref_audio = {"tokens": ref_tokens} + + positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio}) + negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio}) + + # Patch model with identity guidance + m = model.clone() + scale = identity_guidance_scale + model_sampling = m.get_model_object("model_sampling") + sigma_start = model_sampling.percent_to_sigma(start_percent) + sigma_end = model_sampling.percent_to_sigma(end_percent) + + def post_cfg_function(args): + if scale == 0: + return args["denoised"] + + sigma = args["sigma"] + sigma_ = sigma[0].item() + if sigma_ > sigma_start or sigma_ < sigma_end: + return args["denoised"] + + cond_pred = args["cond_denoised"] + cond = args["cond"] + cfg_result = args["denoised"] + model_options = args["model_options"].copy() + x = args["input"] + + # Strip ref_audio from conditioning for the no-reference pass + noref_cond 
= [] + for entry in cond: + new_entry = entry.copy() + mc = new_entry.get("model_conds", {}).copy() + mc.pop("ref_audio", None) + new_entry["model_conds"] = mc + noref_cond.append(new_entry) + + (pred_noref,) = comfy.samplers.calc_cond_batch( + args["model"], [noref_cond], x, sigma, model_options + ) + + return cfg_result + (cond_pred - pred_noref) * scale + + m.set_model_sampler_post_cfg_function(post_cfg_function) + + return io.NodeOutput(m, positive, negative) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, + LTXVReferenceAudio, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000..3e4222264 --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,225 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class 
LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXAVTextEncoderLoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXAVTextEncoderLoader", + display_name="LTXV Audio Text Encoder Loader", + category="advanced/loaders", + description="[Recipes]\n\nltxav: gemma 3 12B", + inputs=[ + io.Combo.Input( + "text_encoder", + options=folder_paths.get_filename_list("text_encoders"), + ), + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + ), + io.Combo.Input( + "device", + options=["default", "cpu"], + advanced=True, + ) + ], + outputs=[io.Clip.Output()], + ) + + @classmethod + def execute(cls, text_encoder, ckpt_name, device="default"): + clip_type = comfy.sd.CLIPType.LTXV + + clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", text_encoder) + clip_path2 = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + + model_options = {} + if device == "cpu": + model_options["load_device"] = model_options["offload_device"] = torch.device("cpu") + + clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type, model_options=model_options) + return io.NodeOutput(clip) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + LTXAVTextEncoderLoader, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000..f99ba13fb --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video 
latent by a factor of 2.
+    """
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "samples": ("LATENT",),
+                "upscale_model": ("LATENT_UPSCALE_MODEL",),
+                "vae": ("VAE",),
+            }
+        }
+
+    RETURN_TYPES = ("LATENT",)
+    FUNCTION = "upsample_latent"
+    CATEGORY = "latent/video"
+    EXPERIMENTAL = True
+
+    def upsample_latent(
+        self,
+        samples: dict,
+        upscale_model,
+        vae,
+    ) -> tuple:
+        """
+        Upsample the input latent using the provided model.
+
+        Args:
+            samples (dict): Input latent samples
+            upscale_model (LatentUpsampler): Loaded upscale model
+            vae: VAE model for normalization
+
+        Returns:
+            tuple: Tuple containing the upsampled latent
+        """
+        device = model_management.get_torch_device()
+        memory_required = model_management.module_size(upscale_model)
+
+        model_dtype = next(upscale_model.parameters()).dtype
+        latents = samples["samples"]
+        input_dtype = latents.dtype
+
+        memory_required += math.prod(latents.shape) * 3000.0  # TODO: more accurate
+        model_management.free_memory(memory_required, device)
+
+        try:
+            upscale_model.to(device)  # TODO: use the comfy model management system.
+
+            latents = latents.to(dtype=model_dtype, device=device)
+
+            # Upsample latents without tiling.
+            latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents)
+            upsampled_latents = upscale_model(latents)
+        finally:
+            upscale_model.cpu()
+
+        upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize(
+            upsampled_latents
+        )
+        upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device())
+        return_dict = samples.copy()
+        return_dict["samples"] = upsampled_latents
+        return_dict.pop("noise_mask", None)
+        return (return_dict,)
+
+
+NODE_CLASS_MAPPINGS = {
+    "LTXVLatentUpsampler": LTXVLatentUpsampler,
+}
diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py
index 89ff2397a..b35ab8b7d 100644
--- a/comfy_extras/nodes_lumina2.py
+++ b/comfy_extras/nodes_lumina2.py
@@ -12,8 +12,8 @@ class RenormCFG(io.ComfyNode):
             category="advanced/model",
             inputs=[
                 io.Model.Input("model"),
-                io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01),
-                io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01),
+                io.Float.Input("cfg_trunc", default=100, min=0.0, max=100.0, step=0.01, advanced=True),
+                io.Float.Input("renorm_cfg", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True),
             ],
             outputs=[
                 io.Model.Output(),
@@ -79,6 +79,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="CLIPTextEncodeLumina2",
+            search_aliases=["lumina prompt"],
             display_name="CLIP Text Encode for Lumina2",
             category="conditioning",
             description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
diff --git a/comfy_extras/nodes_mahiro.py b/comfy_extras/nodes_mahiro.py
index 07b3353f4..a25226e6d 100644
--- a/comfy_extras/nodes_mahiro.py
+++ b/comfy_extras/nodes_mahiro.py
@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="Mahiro",
-            display_name="Mahiro is so cute that she deserves a better guidance function!! 
(。・ω・。)", + display_name="Positive-Biased Guidance", category="_for_testing", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", inputs=[ @@ -20,27 +20,35 @@ class Mahiro(io.ComfyNode): io.Model.Output(display_name="patched_model"), ], is_experimental=True, + search_aliases=[ + "mahiro", + "mahiro cfg", + "similarity-adaptive guidance", + "positive-biased cfg", + ], ) @classmethod def execute(cls, model) -> io.NodeOutput: m = model.clone() + def mahiro_normd(args): - scale: float = args['cond_scale'] - cond_p: torch.Tensor = args['cond_denoised'] - uncond_p: torch.Tensor = args['uncond_denoised'] - #naive leap + scale: float = args["cond_scale"] + cond_p: torch.Tensor = args["cond_denoised"] + uncond_p: torch.Tensor = args["uncond_denoised"] + # naive leap leap = cond_p * scale - #sim with uncond leap + # sim with uncond leap u_leap = uncond_p * scale cfg = args["denoised"] merge = (leap + cfg) / 2 normu = torch.sqrt(u_leap.abs()) * u_leap.sign() normm = torch.sqrt(merge.abs()) * merge.sign() sim = F.cosine_similarity(normu, normm).mean() - simsc = 2 * (sim+1) - wm = (simsc*cfg + (4-simsc)*leap) / 4 + simsc = 2 * (sim + 1) + wm = (simsc * cfg + (4 - simsc) * leap) / 4 return wm + m.set_model_sampler_post_cfg_function(mahiro_normd) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index 290e6f55e..c44602597 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -50,6 +50,7 @@ class LatentCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="LatentCompositeMasked", + search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], category="latent", inputs=[ IO.Latent.Input("destination"), @@ -78,6 +79,7 @@ class ImageCompositeMasked(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCompositeMasked", + search_aliases=["paste image", "overlay", "layer"], category="image", inputs=[ IO.Image.Input("destination"), @@ -105,6 +107,7 @@ class MaskToImage(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskToImage", + search_aliases=["convert mask"], display_name="Convert Mask to Image", category="mask", inputs=[ @@ -126,6 +129,7 @@ class ImageToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageToMask", + search_aliases=["extract channel", "channel to mask"], display_name="Convert Image to Mask", category="mask", inputs=[ @@ -149,6 +153,7 @@ class ImageColorToMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageColorToMask", + search_aliases=["color keying", "chroma key"], category="mask", inputs=[ IO.Image.Input("image"), @@ -194,6 +199,7 @@ class InvertMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="InvertMask", + search_aliases=["reverse mask", "flip mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -214,6 +220,7 @@ class CropMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="CropMask", + search_aliases=["cut mask", "extract mask region", "mask slice"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -239,6 +246,7 @@ class MaskComposite(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskComposite", + search_aliases=["combine masks", "blend masks", "layer masks"], category="mask", inputs=[ IO.Mask.Input("destination"), @@ -287,6 +295,7 @@ class FeatherMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( 
node_id="FeatherMask", + search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -333,12 +342,13 @@ class GrowMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="GrowMask", + search_aliases=["expand mask", "shrink mask"], display_name="Grow Mask", category="mask", inputs=[ IO.Mask.Input("mask"), IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), - IO.Boolean.Input("tapered_corners", default=True), + IO.Boolean.Input("tapered_corners", default=True, advanced=True), ], outputs=[IO.Mask.Output()], ) @@ -370,6 +380,7 @@ class ThresholdMask(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ThresholdMask", + search_aliases=["binary mask"], category="mask", inputs=[ IO.Mask.Input("mask"), @@ -394,6 +405,7 @@ class MaskPreview(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="MaskPreview", + search_aliases=["show mask", "view mask", "inspect mask", "debug mask"], display_name="Preview Mask", category="mask", description="Saves the input images to your ComfyUI output directory.", diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py new file mode 100644 index 000000000..6417bacf1 --- /dev/null +++ b/comfy_extras/nodes_math.py @@ -0,0 +1,119 @@ +"""Math expression node using simpleeval for safe evaluation. + +Provides a ComfyMathExpression node that evaluates math expressions +against dynamically-grown numeric inputs. +""" + +from __future__ import annotations + +import math +import string + +from simpleeval import simple_eval +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +MAX_EXPONENT = 4000 + + +def _variadic_sum(*args): + """Support both sum(values) and sum(a, b, c).""" + if len(args) == 1 and hasattr(args[0], "__iter__"): + return sum(args[0]) + return sum(args) + + +def _safe_pow(base, exp): + """Wrap pow() with an exponent cap to prevent DoS via huge exponents. + + The ** operator is already guarded by simpleeval's safe_power, but + pow() as a callable bypasses that guard. 
+ """ + if abs(exp) > MAX_EXPONENT: + raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})") + return pow(base, exp) + + +MATH_FUNCTIONS = { + "sum": _variadic_sum, + "min": min, + "max": max, + "abs": abs, + "round": round, + "pow": _safe_pow, + "sqrt": math.sqrt, + "ceil": math.ceil, + "floor": math.floor, + "log": math.log, + "log2": math.log2, + "log10": math.log10, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "int": int, + "float": float, +} + + +class MathExpressionNode(io.ComfyNode): + """Evaluates a math expression against dynamically-grown inputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.MultiType.Input("value", [io.Float, io.Int]), + names=list(string.ascii_lowercase), + min=1, + ) + return io.Schema( + node_id="ComfyMathExpression", + display_name="Math Expression", + category="math", + search_aliases=[ + "expression", "formula", "calculate", "calculator", + "eval", "math", + ], + inputs=[ + io.String.Input("expression", default="a + b", multiline=True), + io.Autogrow.Input("values", template=autogrow), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute( + cls, expression: str, values: io.Autogrow.Type + ) -> io.NodeOutput: + if not expression.strip(): + raise ValueError("Expression cannot be empty.") + + context: dict = dict(values) + context["values"] = list(values.values()) + + result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS) + # bool check must come first because bool is a subclass of int in Python + if isinstance(result, bool) or not isinstance(result, (int, float)): + raise ValueError( + f"Math Expression '{expression}' must evaluate to a numeric result, " + f"got {type(result).__name__}: {result!r}" + ) + if not math.isfinite(result): + raise ValueError( + f"Math Expression '{expression}' produced a non-finite result: {result}" + ) + return io.NodeOutput(float(result), int(result)) + + +class MathExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [MathExpressionNode] + + +async def comfy_entrypoint() -> MathExtension: + return MathExtension() diff --git a/comfy_extras/nodes_model_advanced.py b/comfy_extras/nodes_model_advanced.py index ae5d2c563..8bf6a1afa 100644 --- a/comfy_extras/nodes_model_advanced.py +++ b/comfy_extras/nodes_model_advanced.py @@ -52,8 +52,8 @@ class ModelSamplingDiscrete: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img"],), - "zsnr": ("BOOLEAN", {"default": False}), + "sampling": (["eps", "v_prediction", "lcm", "x0", "img_to_img", "img_to_img_flow"],), + "zsnr": ("BOOLEAN", {"default": False, "advanced": True}), }} RETURN_TYPES = ("MODEL",) @@ -76,6 +76,8 @@ class ModelSamplingDiscrete: sampling_type = comfy.model_sampling.X0 elif sampling == "img_to_img": sampling_type = comfy.model_sampling.IMG_TO_IMG + elif sampling == "img_to_img_flow": + sampling_type = comfy.model_sampling.IMG_TO_IMG_FLOW class ModelSamplingAdvanced(sampling_base, sampling_type): pass @@ -153,8 +155,8 @@ class ModelSamplingFlux: @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}), - "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}), + "max_shift": ("FLOAT", {"default": 1.15, "min": 
0.0, "max": 100.0, "step":0.01, "advanced": True}), + "base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "advanced": True}), "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}), }} @@ -190,8 +192,8 @@ class ModelSamplingContinuousEDM: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "sampling": (["v_prediction", "edm", "edm_playground_v2.5", "eps", "cosmos_rflow"],), - "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_max": ("FLOAT", {"default": 120.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}), + "sigma_min": ("FLOAT", {"default": 0.002, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}), }} RETURN_TYPES = ("MODEL",) @@ -235,8 +237,8 @@ class ModelSamplingContinuousV: def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), "sampling": (["v_prediction"],), - "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False}), + "sigma_max": ("FLOAT", {"default": 500.0, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}), + "sigma_min": ("FLOAT", {"default": 0.03, "min": 0.0, "max": 1000.0, "step":0.001, "round": False, "advanced": True}), }} RETURN_TYPES = ("MODEL",) @@ -299,10 +301,11 @@ class RescaleCFG: return (m, ) class ModelComputeDtype: + SEARCH_ALIASES = ["model precision", "change dtype"] @classmethod def INPUT_TYPES(s): return {"required": { "model": ("MODEL",), - "dtype": (["default", "fp32", "fp16", "bf16"],), + "dtype": (["default", "fp32", "fp16", "bf16"], {"advanced": True}), }} RETURN_TYPES = ("MODEL",) diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index dec2ae841..24d47a903 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -13,11 +13,11 @@ class PatchModelAddDownscale(io.ComfyNode): category="model_patches/unet", inputs=[ io.Model.Input("model"), - io.Int.Input("block_number", default=3, min=1, max=32, step=1), + io.Int.Input("block_number", default=3, min=1, max=32, step=1, advanced=True), io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001), - io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), - io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001), - io.Boolean.Input("downscale_after_skip", default=True), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001, advanced=True), + io.Boolean.Input("downscale_after_skip", default=True, advanced=True), io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS), io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS), ], diff --git a/comfy_extras/nodes_model_merging.py b/comfy_extras/nodes_model_merging.py index f20beab7d..5384ed531 100644 --- a/comfy_extras/nodes_model_merging.py +++ b/comfy_extras/nodes_model_merging.py @@ -91,6 +91,7 @@ class CLIPMergeSimple: class CLIPSubtract: + SEARCH_ALIASES = ["clip difference", "text encoder subtract"] @classmethod def INPUT_TYPES(s): return {"required": { 
"clip1": ("CLIP",), @@ -113,6 +114,7 @@ class CLIPSubtract: class CLIPAdd: + SEARCH_ALIASES = ["combine clip"] @classmethod def INPUT_TYPES(s): return {"required": { "clip1": ("CLIP",), @@ -225,6 +227,7 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi comfy.sd.save_checkpoint(output_checkpoint, model, clip, vae, clip_vision, metadata=metadata, extra_keys=extra_keys) class CheckpointSave: + SEARCH_ALIASES = ["save model", "export checkpoint", "merge save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() @@ -337,6 +340,7 @@ class VAESave: return {} class ModelSave: + SEARCH_ALIASES = ["export model", "checkpoint save"] def __init__(self): self.output_dir = folder_paths.get_output_directory() diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 2a0cfcf18..176e6bc2f 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -7,6 +7,7 @@ import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats import comfy.ldm.lumina.controlnet +from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel class BlockWiseControlBlock(torch.nn.Module): @@ -244,6 +245,10 @@ class ModelPatchLoader: elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet sd = z_image_convert(sd) config = {} + if 'control_layers.4.adaLN_modulation.0.weight' not in sd: + config['n_control_layers'] = 3 + config['additional_in_dim'] = 17 + config['refiner_control'] = True if 'control_layers.14.adaLN_modulation.0.weight' in sd: config['n_control_layers'] = 15 config['additional_in_dim'] = 17 @@ -253,10 +258,18 @@ class ModelPatchLoader: if torch.count_nonzero(ref_weight) == 0: config['broken'] = True model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config) + elif "audio_proj.proj1.weight" in sd: + model = MultiTalkModelPatch( + audio_window=5, context_tokens=32, vae_scale=4, + in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0], + intermediate_dim=sd["audio_proj.proj1.weight"].shape[0], + out_dim=sd["audio_proj.norm.weight"].shape[0], + device=comfy.model_management.unet_offload_device(), + operations=comfy.ops.manual_cast) - model.load_state_dict(sd) - model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) - return (model,) + model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) + model.load_state_dict(sd, assign=model_patcher.is_dynamic()) + return (model_patcher,) class DiffSynthCnetPatch: @@ -348,7 +361,7 @@ class ZImageControlPatch: if self.mask is None: mask_ = torch.zeros_like(inpaint_image_latent)[:, :1] else: - mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") + mask_ = comfy.utils.common_upscale(self.mask.view(self.mask.shape[0], -1, self.mask.shape[-2], self.mask.shape[-1]).mean(dim=1, keepdim=True).to(device=inpaint_image_latent.device), inpaint_image_latent.shape[-1], inpaint_image_latent.shape[-2], "nearest", "center") if latent_image is None: latent_image = 
comfy.latent_formats.Flux().process_in(self.vae.encode(torch.ones_like(inpaint_image) * 0.5)) @@ -520,6 +533,38 @@ class USOStyleReference: return (model_patched,) +class MultiTalkModelPatch(torch.nn.Module): + def __init__( + self, + audio_window: int = 5, + intermediate_dim: int = 512, + in_dim: int = 5120, + out_dim: int = 768, + context_tokens: int = 32, + vae_scale: int = 4, + num_layers: int = 40, + + device=None, dtype=None, operations=None + ): + super().__init__() + self.audio_proj = MultiTalkAudioProjModel( + seq_len=audio_window, + seq_len_vf=audio_window+vae_scale-1, + intermediate_dim=intermediate_dim, + out_dim=out_dim, + context_tokens=context_tokens, + device=device, + dtype=dtype, + operations=operations + ) + self.blocks = torch.nn.ModuleList( + [ + WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ] + ) + + NODE_CLASS_MAPPINGS = { "ModelPatchLoader": ModelPatchLoader, "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet, diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 67377e1bc..4ab2fb7e8 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -12,6 +12,7 @@ class Morphology(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Morphology", + search_aliases=["erode", "dilate"], display_name="ImageMorphology", category="image/postprocessing", inputs=[ @@ -57,6 +58,7 @@ class ImageRGBToYUV(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageRGBToYUV", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("image"), @@ -78,6 +80,7 @@ class ImageYUVToRGB(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageYUVToRGB", + search_aliases=["color space conversion"], category="image/batch", inputs=[ io.Image.Input("Y"), diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py new file mode 100644 index 000000000..b57181848 --- /dev/null +++ b/comfy_extras/nodes_nag.py @@ -0,0 +1,99 @@ +import torch +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class NAGuidance(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="NAGuidance", + display_name="Normalized Attention Guidance", + description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.", + category="advanced/guidance", + is_experimental=True, + inputs=[ + io.Model.Input("model", tooltip="The model to apply NAG to."), + io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."), + io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 
1.0 is full replacement, 0.0 is no effect."), + io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01), + # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."), + # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."), + ], + outputs=[ + io.Model.Output(tooltip="The patched model with NAG enabled."), + ], + ) + + @classmethod + def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput: + m = model.clone() + + # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent) + # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent) + + def nag_attention_output_patch(out, extra_options): + cond_or_uncond = extra_options.get("cond_or_uncond", None) + if cond_or_uncond is None: + return out + + if not (1 in cond_or_uncond and 0 in cond_or_uncond): + return out + + # sigma = extra_options.get("sigmas", None) + # if sigma is not None and len(sigma) > 0: + # sigma = sigma[0].item() + # if sigma > sigma_start or sigma < sigma_end: + # return out + + img_slice = extra_options.get("img_slice", None) + + if img_slice is not None: + orig_out = out + out = out[:, img_slice[0]:img_slice[1]] # only apply on img part + + batch_size = out.shape[0] + half_size = batch_size // len(cond_or_uncond) + + ind_neg = cond_or_uncond.index(1) + ind_pos = cond_or_uncond.index(0) + z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)] + z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)] + + guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0) + + eps = 1e-6 + norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps) + norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps) + + ratio = norm_guided / norm_pos + scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio + + guided_normalized = guided * scale_factor + + z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha) + + if img_slice is not None: + orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final + orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final + return orig_out + else: + out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final + return out + + m.set_model_attn1_output_patch(nag_attention_output_patch) + m.disable_model_cfg1_optimization() + + return io.NodeOutput(m) + + +class NagExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + NAGuidance, + ] + + +async def comfy_entrypoint() -> NagExtension: + return NagExtension() diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py new file mode 100644 index 000000000..b2822c856 --- /dev/null +++ b/comfy_extras/nodes_number_convert.py @@ -0,0 +1,79 @@ +"""Number Convert node for unified numeric type conversion. + +Provides a single node that converts INT, FLOAT, STRING, and BOOL +inputs into FLOAT and INT outputs. 
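+
+Example conversions: "3.5" -> (3.5, 3), True -> (1.0, 1), 7 -> (7.0, 7);
+the INT output truncates toward zero rather than rounding.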
+""" + +from __future__ import annotations + +import math + +from typing_extensions import override + +from comfy_api.latest import ComfyExtension, io + + +class NumberConvertNode(io.ComfyNode): + """Converts various types to numeric FLOAT and INT outputs.""" + + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="ComfyNumberConvert", + display_name="Number Convert", + category="math", + search_aliases=[ + "int to float", "float to int", "number convert", + "int2float", "float2int", "cast", "parse number", + "string to number", "bool to int", + ], + inputs=[ + io.MultiType.Input( + "value", + [io.Int, io.Float, io.String, io.Boolean], + display_name="value", + ), + ], + outputs=[ + io.Float.Output(display_name="FLOAT"), + io.Int.Output(display_name="INT"), + ], + ) + + @classmethod + def execute(cls, value) -> io.NodeOutput: + if isinstance(value, bool): + float_val = 1.0 if value else 0.0 + elif isinstance(value, (int, float)): + float_val = float(value) + elif isinstance(value, str): + text = value.strip() + if not text: + raise ValueError("Cannot convert empty string to number.") + try: + float_val = float(text) + except ValueError: + raise ValueError( + f"Cannot convert string to number: {value!r}" + ) from None + else: + raise TypeError( + f"Unsupported input type: {type(value).__name__}" + ) + + if not math.isfinite(float_val): + raise ValueError( + f"Cannot convert non-finite value to number: {float_val}" + ) + + return io.NodeOutput(float_val, int(float_val)) + + +class NumberConvertExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [NumberConvertNode] + + +async def comfy_entrypoint() -> NumberConvertExtension: + return NumberConvertExtension() diff --git a/comfy_extras/nodes_painter.py b/comfy_extras/nodes_painter.py new file mode 100644 index 000000000..b9ecdf5ea --- /dev/null +++ b/comfy_extras/nodes_painter.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +import hashlib +import os + +import numpy as np +import torch +from PIL import Image + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io, UI +from typing_extensions import override + + +def hex_to_rgb(hex_color: str) -> tuple[float, float, float]: + hex_color = hex_color.lstrip("#") + if len(hex_color) != 6: + return (0.0, 0.0, 0.0) + r = int(hex_color[0:2], 16) / 255.0 + g = int(hex_color[2:4], 16) / 255.0 + b = int(hex_color[4:6], 16) / 255.0 + return (r, g, b) + + +class PainterNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Painter", + display_name="Painter", + category="image", + inputs=[ + io.Image.Input( + "image", + optional=True, + tooltip="Optional base image to paint over", + ), + io.String.Input( + "mask", + default="", + socketless=True, + extra_dict={"widgetType": "PAINTER", "image_upload": True}, + ), + io.Int.Input( + "width", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Int.Input( + "height", + default=512, + min=64, + max=4096, + step=64, + socketless=True, + extra_dict={"hidden": True}, + ), + io.Color.Input("bg_color", default="#000000"), + ], + outputs=[ + io.Image.Output("IMAGE"), + io.Mask.Output("MASK"), + ], + ) + + @classmethod + def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput: + if image is not None: + base_image = image[:1] + h, w = base_image.shape[1], base_image.shape[2] + else: + h, w = height, width + r, g, b = 
hex_to_rgb(bg_color) + base_image = torch.zeros((1, h, w, 3), dtype=torch.float32) + base_image[0, :, :, 0] = r + base_image[0, :, :, 1] = g + base_image[0, :, :, 2] = b + + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + painter_img = node_helpers.pillow(Image.open, mask_path) + painter_img = painter_img.convert("RGBA") + + if painter_img.size != (w, h): + painter_img = painter_img.resize((w, h), Image.LANCZOS) + + painter_np = np.array(painter_img).astype(np.float32) / 255.0 + painter_rgb = painter_np[:, :, :3] + painter_alpha = painter_np[:, :, 3:4] + + mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0) + + base_np = base_image[0].cpu().numpy() + composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha) + out_image = torch.from_numpy(composited).unsqueeze(0) + else: + mask_tensor = torch.zeros((1, h, w), dtype=torch.float32) + out_image = base_image + + return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image)) + + @classmethod + def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None): + if mask and mask.strip(): + mask_path = folder_paths.get_annotated_filepath(mask) + if os.path.exists(mask_path): + m = hashlib.sha256() + with open(mask_path, "rb") as f: + m.update(f.read()) + return m.digest().hex() + return "" + + + +class PainterExtension(ComfyExtension): + @override + async def get_node_list(self): + return [PainterNode] + + +async def comfy_entrypoint(): + return PainterExtension() diff --git a/comfy_extras/nodes_perpneg.py b/comfy_extras/nodes_perpneg.py index cd068ce9c..ed1467de9 100644 --- a/comfy_extras/nodes_perpneg.py +++ b/comfy_extras/nodes_perpneg.py @@ -29,7 +29,7 @@ class PerpNeg(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Conditioning.Input("empty_conditioning"), - io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True), ], outputs=[ io.Model.Output(), @@ -134,7 +134,7 @@ class PerpNegGuider(io.ComfyNode): io.Conditioning.Input("negative"), io.Conditioning.Input("empty_conditioning"), io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), - io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01), + io.Float.Input("neg_scale", default=1.0, min=0.0, max=100.0, step=0.01, advanced=True), ], outputs=[ io.Guider.Output(), diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py index a23e87b1f..2f1b73e60 100644 --- a/comfy_extras/nodes_pixart.py +++ b/comfy_extras/nodes_pixart.py @@ -7,6 +7,7 @@ class CLIPTextEncodePixArtAlpha(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodePixArtAlpha", + search_aliases=["pixart prompt"], category="advanced/conditioning", description="Encodes text and sets the resolution conditioning for PixArt Alpha. 
Does not apply to PixArt Sigma.", inputs=[ diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index ca2cdeb50..06626f9dd 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -4,18 +4,24 @@ import torch import torch.nn.functional as F from PIL import Image import math +from enum import Enum +from typing import TypedDict, Literal import comfy.utils import comfy.model_management +from comfy_extras.nodes_latent import reshape_latent_to import node_helpers from comfy_api.latest import ComfyExtension, io +from nodes import MAX_RESOLUTION class Blend(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="ImageBlend", + display_name="Image Blend", category="image/postprocessing", + essentials_category="Image Tools", inputs=[ io.Image.Input("image1"), io.Image.Input("image2"), @@ -72,6 +78,7 @@ class Blur(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ImageBlur", + display_name="Image Blur", category="image/postprocessing", inputs=[ io.Image.Input("image"), @@ -175,9 +182,9 @@ class Sharpen(io.ComfyNode): category="image/postprocessing", inputs=[ io.Image.Input("image"), - io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1), - io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01), - io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01), + io.Int.Input("sharpen_radius", default=1, min=1, max=31, step=1, advanced=True), + io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.01, advanced=True), + io.Float.Input("alpha", default=1.0, min=0.0, max=5.0, step=0.01, advanced=True), ], outputs=[ io.Image.Output(), @@ -221,7 +228,7 @@ class ImageScaleToTotalPixels(io.ComfyNode): io.Image.Input("image"), io.Combo.Input("upscale_method", options=cls.upscale_methods), io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01), - io.Int.Input("resolution_steps", default=1, min=1, max=256), + io.Int.Input("resolution_steps", default=1, min=1, max=256, advanced=True), ], outputs=[ io.Image.Output(), @@ -241,6 +248,418 @@ class ImageScaleToTotalPixels(io.ComfyNode): s = s.movedim(1,-1) return io.NodeOutput(s) +class ResizeType(str, Enum): + SCALE_BY = "scale by multiplier" + SCALE_DIMENSIONS = "scale dimensions" + SCALE_LONGER_DIMENSION = "scale longer dimension" + SCALE_SHORTER_DIMENSION = "scale shorter dimension" + SCALE_WIDTH = "scale width" + SCALE_HEIGHT = "scale height" + SCALE_TOTAL_PIXELS = "scale total pixels" + MATCH_SIZE = "match size" + SCALE_TO_MULTIPLE = "scale to multiple" + +def is_image(input: torch.Tensor) -> bool: + # images have 4 dimensions: [batch, height, width, channels] + # masks have 3 dimensions: [batch, height, width] + return len(input.shape) == 4 + +def init_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(-1, 1) + else: + input = input.unsqueeze(1) + return input + +def finalize_image_mask_input(input: torch.Tensor, is_type_image: bool) -> torch.Tensor: + if is_type_image: + input = input.movedim(1, -1) + else: + input = input.squeeze(1) + return input + +def scale_by(input: torch.Tensor, multiplier: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = round(input.shape[-1] * multiplier) + height = round(input.shape[-2] * multiplier) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = 
finalize_image_mask_input(input, is_type_image) + return input + +def scale_dimensions(input: torch.Tensor, width: int, height: int, scale_method: str, crop: str="disabled") -> torch.Tensor: + if width == 0 and height == 0: + return input + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + + if width == 0: + width = max(1, round(input.shape[-1] * height / input.shape[-2])) + elif height == 0: + height = max(1, round(input.shape[-2] * width / input.shape[-1])) + + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_longer_dimension(input: torch.Tensor, longer_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height > width: + width = round((width / height) * longer_size) + height = longer_size + elif width > height: + height = round((height / width) * longer_size) + width = longer_size + else: + height = longer_size + width = longer_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_shorter_dimension(input: torch.Tensor, shorter_size: int, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + width = input.shape[-1] + height = input.shape[-2] + + if height < width: + width = round((width / height) * shorter_size) + height = shorter_size + elif width < height: + height = round((height / width) * shorter_size) + width = shorter_size + else: + height = shorter_size + width = shorter_size + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_total_pixels(input: torch.Tensor, megapixels: float, scale_method: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + total = int(megapixels * 1024 * 1024) + + scale_by = math.sqrt(total / (input.shape[-1] * input.shape[-2])) + width = round(input.shape[-1] * scale_by) + height = round(input.shape[-2] * scale_by) + + input = comfy.utils.common_upscale(input, width, height, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str, crop: str) -> torch.Tensor: + is_type_image = is_image(input) + input = init_image_mask_input(input, is_type_image) + match = init_image_mask_input(match, is_image(match)) + + width = match.shape[-1] + height = match.shape[-2] + input = comfy.utils.common_upscale(input, width, height, scale_method, crop) + input = finalize_image_mask_input(input, is_type_image) + return input + +def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor: + if multiple <= 1: + return input + is_type_image = is_image(input) + if is_type_image: + _, height, width, _ = input.shape + else: + _, height, width = input.shape + target_w = (width // multiple) * multiple + target_h = (height // multiple) * multiple + if target_w == 0 or target_h == 0: + return input + if target_w == width and target_h == height: + return input + s_w = target_w / width + s_h = target_h / height + if s_w >= s_h: + scaled_w = target_w + scaled_h = 
int(math.ceil(height * s_w)) + if scaled_h < target_h: + scaled_h = target_h + else: + scaled_h = target_h + scaled_w = int(math.ceil(width * s_h)) + if scaled_w < target_w: + scaled_w = target_w + input = init_image_mask_input(input, is_type_image) + input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled") + input = finalize_image_mask_input(input, is_type_image) + x0 = (scaled_w - target_w) // 2 + y0 = (scaled_h - target_h) // 2 + x1 = x0 + target_w + y1 = y0 + target_h + if is_type_image: + return input[:, y0:y1, x0:x1, :] + return input[:, y0:y1, x0:x1] + +class ResizeImageMaskNode(io.ComfyNode): + scale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop_methods = ["disabled", "center"] + + class ResizeTypedDict(TypedDict): + resize_type: ResizeType + scale_method: Literal["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] + crop: Literal["disabled", "center"] + multiplier: float + width: int + height: int + longer_size: int + shorter_size: int + megapixels: float + multiple: int + + @classmethod + def define_schema(cls): + template = io.MatchType.Template("input_type", [io.Image, io.Mask]) + crop_combo = io.Combo.Input( + "crop", + options=cls.crop_methods, + default="center", + tooltip="How to handle aspect ratio mismatch: 'disabled' stretches to fit, 'center' crops to maintain aspect ratio.", + ) + return io.Schema( + node_id="ResizeImageMaskNode", + display_name="Resize Image/Mask", + description="Resize an image or mask using various scaling methods.", + category="transform", + search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"], + inputs=[ + io.MatchType.Input("input", template=template), + io.DynamicCombo.Input( + "resize_type", + tooltip="Select how to resize: by exact dimensions, scale factor, matching another image, etc.", + options=[ + io.DynamicCombo.Option(ResizeType.SCALE_DIMENSIONS, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Set to 0 to auto-calculate from height while preserving aspect ratio."), + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. Set to 0 to auto-calculate from width while preserving aspect ratio."), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_BY, [ + io.Float.Input("multiplier", default=1.00, min=0.01, max=8.0, step=0.01, tooltip="Scale factor (e.g., 2.0 doubles size, 0.5 halves size)."), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_LONGER_DIMENSION, [ + io.Int.Input("longer_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The longer edge will be resized to this value. Aspect ratio is preserved."), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_SHORTER_DIMENSION, [ + io.Int.Input("shorter_size", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="The shorter edge will be resized to this value. Aspect ratio is preserved."), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_WIDTH, [ + io.Int.Input("width", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target width in pixels. Height auto-adjusts to preserve aspect ratio."), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_HEIGHT, [ + io.Int.Input("height", default=512, min=0, max=MAX_RESOLUTION, step=1, tooltip="Target height in pixels. 
Width auto-adjusts to preserve aspect ratio."), + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TOTAL_PIXELS, [ + io.Float.Input("megapixels", default=1.0, min=0.01, max=16.0, step=0.01, tooltip="Target total megapixels (e.g., 1.0 ≈ 1024×1024). Aspect ratio is preserved."), + ]), + io.DynamicCombo.Option(ResizeType.MATCH_SIZE, [ + io.MultiType.Input("match", [io.Image, io.Mask], tooltip="Resize input to match the dimensions of this reference image or mask."), + crop_combo, + ]), + io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [ + io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1, tooltip="Resize so width and height are divisible by this number. Useful for latent alignment (e.g., 8 or 64)."), + ]), + ], + ), + io.Combo.Input( + "scale_method", + options=cls.scale_methods, + default="area", + tooltip="Interpolation algorithm. 'area' is best for downscaling, 'lanczos' for upscaling, 'nearest-exact' for pixel art.", + ), + ], + outputs=[io.MatchType.Output(template=template, display_name="resized")] + ) + + @classmethod + def execute(cls, input: io.Image.Type | io.Mask.Type, scale_method: io.Combo.Type, resize_type: ResizeTypedDict) -> io.NodeOutput: + selected_type = resize_type["resize_type"] + if selected_type == ResizeType.SCALE_BY: + return io.NodeOutput(scale_by(input, resize_type["multiplier"], scale_method)) + elif selected_type == ResizeType.SCALE_DIMENSIONS: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], resize_type["height"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_LONGER_DIMENSION: + return io.NodeOutput(scale_longer_dimension(input, resize_type["longer_size"], scale_method)) + elif selected_type == ResizeType.SCALE_SHORTER_DIMENSION: + return io.NodeOutput(scale_shorter_dimension(input, resize_type["shorter_size"], scale_method)) + elif selected_type == ResizeType.SCALE_WIDTH: + return io.NodeOutput(scale_dimensions(input, resize_type["width"], 0, scale_method)) + elif selected_type == ResizeType.SCALE_HEIGHT: + return io.NodeOutput(scale_dimensions(input, 0, resize_type["height"], scale_method)) + elif selected_type == ResizeType.SCALE_TOTAL_PIXELS: + return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method)) + elif selected_type == ResizeType.MATCH_SIZE: + return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"])) + elif selected_type == ResizeType.SCALE_TO_MULTIPLE: + return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method)) + raise ValueError(f"Unsupported resize type: {selected_type}") + +def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None: + if len(images) == 0: + return None + # first, get the max channels count + max_channels = max(image.shape[-1] for image in images) + # then, pad all images to have the same channels count + padded_images: list[torch.Tensor] = [] + for image in images: + if image.shape[-1] < max_channels: + padded_images.append(torch.nn.functional.pad(image, (0,1), mode='constant', value=1.0)) + else: + padded_images.append(image) + # resize all images to be the same size as the first image + resized_images: list[torch.Tensor] = [] + first_image_shape = padded_images[0].shape + for image in padded_images: + if image.shape[1:] != first_image_shape[1:]: + resized_images.append(comfy.utils.common_upscale(image.movedim(-1,1), first_image_shape[2], first_image_shape[1], "bilinear", "center").movedim(1,-1)) + else: + resized_images.append(image) + # batch 
the images in the format [b, h, w, c] + return torch.cat(resized_images, dim=0) + +def batch_masks(masks: list[torch.Tensor]) -> torch.Tensor | None: + if len(masks) == 0: + return None + # resize all masks to be the same size as the first mask + resized_masks: list[torch.Tensor] = [] + first_mask_shape = masks[0].shape + for mask in masks: + if mask.shape[1:] != first_mask_shape[1:]: + mask = init_image_mask_input(mask, is_type_image=False) + mask = comfy.utils.common_upscale(mask, first_mask_shape[2], first_mask_shape[1], "bilinear", "center") + resized_masks.append(finalize_image_mask_input(mask, is_type_image=False)) + else: + resized_masks.append(mask) + # batch the masks in the format [b, h, w] + return torch.cat(resized_masks, dim=0) + +def batch_latents(latents: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor] | None: + if len(latents) == 0: + return None + samples_out = latents[0].copy() + samples_out["batch_index"] = [] + first_samples = latents[0]["samples"] + tensors: list[torch.Tensor] = [] + for latent in latents: + # first, deal with latent tensors + tensors.append(reshape_latent_to(first_samples.shape, latent["samples"], repeat_batch=False)) + # next, deal with batch_index + samples_out["batch_index"].extend(latent.get("batch_index", [x for x in range(0, latent["samples"].shape[0])])) + samples_out["samples"] = torch.cat(tensors, dim=0) + return samples_out + +class BatchImagesNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image", min=2, max=50) + return io.Schema( + node_id="BatchImagesNode", + display_name="Batch Images", + category="image", + essentials_category="Image Tools", + search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"], + inputs=[ + io.Autogrow.Input("images", template=autogrow_template) + ], + outputs=[ + io.Image.Output() + ] + ) + + @classmethod + def execute(cls, images: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_images(list(images.values()))) + +class BatchMasksNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Mask.Input("mask"), prefix="mask", min=2, max=50) + return io.Schema( + node_id="BatchMasksNode", + search_aliases=["combine masks", "stack masks", "merge masks"], + display_name="Batch Masks", + category="mask", + inputs=[ + io.Autogrow.Input("masks", template=autogrow_template) + ], + outputs=[ + io.Mask.Output() + ] + ) + + @classmethod + def execute(cls, masks: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_masks(list(masks.values()))) + +class BatchLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + autogrow_template = io.Autogrow.TemplatePrefix(io.Latent.Input("latent"), prefix="latent", min=2, max=50) + return io.Schema( + node_id="BatchLatentsNode", + search_aliases=["combine latents", "stack latents", "merge latents"], + display_name="Batch Latents", + category="latent", + inputs=[ + io.Autogrow.Input("latents", template=autogrow_template) + ], + outputs=[ + io.Latent.Output() + ] + ) + + @classmethod + def execute(cls, latents: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(batch_latents(list(latents.values()))) + +class BatchImagesMasksLatentsNode(io.ComfyNode): + @classmethod + def define_schema(cls): + matchtype_template = io.MatchType.Template("input", allowed_types=[io.Image, io.Mask, io.Latent]) + autogrow_template = io.Autogrow.TemplatePrefix( + 
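+ # all grown inputs share one MatchType template, so every connected input
+ # (and the output) must resolve to the same type: Image, Mask, or Latent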
io.MatchType.Input("input", matchtype_template), + prefix="input", min=1, max=50) + return io.Schema( + node_id="BatchImagesMasksLatentsNode", + search_aliases=["combine batch", "merge batch", "stack inputs"], + display_name="Batch Images/Masks/Latents", + category="util", + inputs=[ + io.Autogrow.Input("inputs", template=autogrow_template) + ], + outputs=[ + io.MatchType.Output(id=None, template=matchtype_template) + ] + ) + + @classmethod + def execute(cls, inputs: io.Autogrow.Type) -> io.NodeOutput: + batched = None + values = list(inputs.values()) + # latents + if isinstance(values[0], dict): + batched = batch_latents(values) + # images + elif is_image(values[0]): + batched = batch_images(values) + # masks + else: + batched = batch_masks(values) + return io.NodeOutput(batched) + + class PostProcessingExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -250,6 +669,11 @@ class PostProcessingExtension(ComfyExtension): Quantize, Sharpen, ImageScaleToTotalPixels, + ResizeImageMaskNode, + BatchImagesNode, + BatchMasksNode, + BatchLatentsNode, + # BatchImagesMasksLatentsNode, ] async def comfy_entrypoint() -> PostProcessingExtension: diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 139b07c93..b0a6f279d 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,6 +16,7 @@ class PreviewAny(): OUTPUT_NODE = True CATEGORY = "utils" + SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): value = 'None' diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 5a1aeba80..9c2e98758 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -29,6 +29,7 @@ class StringMultiline(io.ComfyNode): node_id="PrimitiveStringMultiline", display_name="String (Multiline)", category="utils/primitive", + essentials_category="Basics", inputs=[ io.String.Input("value", multiline=True), ], @@ -66,7 +67,7 @@ class Float(io.ComfyNode): display_name="Float", category="utils/primitive", inputs=[ - io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize), + io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], outputs=[io.Float.Output()], ) diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index 525239ae5..6894367be 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -3,7 +3,9 @@ import comfy.utils import math from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import comfy.model_management +import torch +import nodes class TextEncodeQwenImageEdit(io.ComfyNode): @classmethod @@ -104,12 +106,37 @@ class TextEncodeQwenImageEditPlus(io.ComfyNode): return io.NodeOutput(conditioning) +class EmptyQwenImageLayeredLatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyQwenImageLayeredLatentImage", + display_name="Empty Qwen Image Layered Latent", + category="latent/qwen", + inputs=[ + io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("layers", default=3, min=0, max=nodes.MAX_RESOLUTION, step=1, advanced=True), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, layers, batch_size=1) -> io.NodeOutput: + latent = 
torch.zeros([batch_size, 16, layers + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + class QwenExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ TextEncodeQwenImageEdit, TextEncodeQwenImageEditPlus, + EmptyQwenImageLayeredLatentImage, ] diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py new file mode 100644 index 000000000..7684e854c --- /dev/null +++ b/comfy_extras/nodes_replacements.py @@ -0,0 +1,103 @@ +from comfy_api.latest import ComfyExtension, io, ComfyAPI + +api = ComfyAPI() + + +async def register_replacements(): + """Register all built-in node replacements.""" + await register_replacements_longeredge() + await register_replacements_batchimages() + await register_replacements_upscaleimage() + await register_replacements_controlnet() + await register_replacements_load3d() + await register_replacements_preview3d() + await register_replacements_svdimg2vid() + await register_replacements_conditioningavg() + +async def register_replacements_longeredge(): + # No dynamic inputs here + await api.node_replacement.register(io.NodeReplace( + new_node_id="ImageScaleToMaxDimension", + old_node_id="ResizeImagesByLongerEdge", + old_widget_ids=["longer_edge"], + input_mapping=[ + {"new_id": "image", "old_id": "images"}, + {"new_id": "largest_size", "old_id": "longer_edge"}, + {"new_id": "upscale_method", "set_value": "lanczos"}, + ], + # just to test the frontend output_mapping code, does nothing really here + output_mapping=[{"new_idx": 0, "old_idx": 0}], + )) + +async def register_replacements_batchimages(): + # BatchImages node uses Autogrow + await api.node_replacement.register(io.NodeReplace( + new_node_id="BatchImagesNode", + old_node_id="ImageBatch", + input_mapping=[ + {"new_id": "images.image0", "old_id": "image1"}, + {"new_id": "images.image1", "old_id": "image2"}, + ], + )) + +async def register_replacements_upscaleimage(): + # ResizeImageMaskNode uses DynamicCombo + await api.node_replacement.register(io.NodeReplace( + new_node_id="ResizeImageMaskNode", + old_node_id="ImageScaleBy", + old_widget_ids=["upscale_method", "scale_by"], + input_mapping=[ + {"new_id": "input", "old_id": "image"}, + {"new_id": "resize_type", "set_value": "scale by multiplier"}, + {"new_id": "resize_type.multiplier", "old_id": "scale_by"}, + {"new_id": "scale_method", "old_id": "upscale_method"}, + ], + )) + +async def register_replacements_controlnet(): + # T2IAdapterLoader → ControlNetLoader + await api.node_replacement.register(io.NodeReplace( + new_node_id="ControlNetLoader", + old_node_id="T2IAdapterLoader", + input_mapping=[ + {"new_id": "control_net_name", "old_id": "t2i_adapter_name"}, + ], + )) + +async def register_replacements_load3d(): + # Load3DAnimation merged into Load3D + await api.node_replacement.register(io.NodeReplace( + new_node_id="Load3D", + old_node_id="Load3DAnimation", + )) + +async def register_replacements_preview3d(): + # Preview3DAnimation merged into Preview3D + await api.node_replacement.register(io.NodeReplace( + new_node_id="Preview3D", + old_node_id="Preview3DAnimation", + )) + +async def register_replacements_svdimg2vid(): + # Typo fix: SDV → SVD + await api.node_replacement.register(io.NodeReplace( + new_node_id="SVD_img2vid_Conditioning", + old_node_id="SDV_img2vid_Conditioning", + )) + +async def register_replacements_conditioningavg(): + # Typo fix: trailing space in node name + await 
api.node_replacement.register(io.NodeReplace( + new_node_id="ConditioningAverage", + old_node_id="ConditioningAverage ", + )) + +class NodeReplacementsExtension(ComfyExtension): + async def on_load(self) -> None: + await register_replacements() + + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [] + +async def comfy_entrypoint() -> NodeReplacementsExtension: + return NodeReplacementsExtension() diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py new file mode 100644 index 000000000..520b4067e --- /dev/null +++ b/comfy_extras/nodes_resolution.py @@ -0,0 +1,86 @@ +from __future__ import annotations +import math +from enum import Enum +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class AspectRatio(str, Enum): + SQUARE = "1:1 (Square)" + PHOTO_H = "3:2 (Photo)" + STANDARD_H = "4:3 (Standard)" + WIDESCREEN_H = "16:9 (Widescreen)" + ULTRAWIDE_H = "21:9 (Ultrawide)" + PHOTO_V = "2:3 (Portrait Photo)" + STANDARD_V = "3:4 (Portrait Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" + + +ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { + AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_H: (16, 9), + AspectRatio.ULTRAWIDE_H: (21, 9), + AspectRatio.PHOTO_V: (2, 3), + AspectRatio.STANDARD_V: (3, 4), + AspectRatio.WIDESCREEN_V: (9, 16), +} + + +class ResolutionSelector(io.ComfyNode): + """Calculate width and height from aspect ratio and megapixel target.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ResolutionSelector", + display_name="Resolution Selector", + category="utils", + description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", + inputs=[ + io.Combo.Input( + "aspect_ratio", + options=AspectRatio, + default=AspectRatio.SQUARE, + tooltip="The aspect ratio for the output dimensions.", + ), + io.Float.Input( + "megapixels", + default=1.0, + min=0.1, + max=16.0, + step=0.1, + tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.", + ), + ], + outputs=[ + io.Int.Output( + "width", tooltip="Calculated width in pixels (multiple of 8)." + ), + io.Int.Output( + "height", tooltip="Calculated height in pixels (multiple of 8)." 
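+ # e.g. 1:1 at 1.0 MP -> 1024x1024; 16:9 at 1.0 MP -> 1368x768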
+ ), + ], + ) + + @classmethod + def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] + total_pixels = megapixels * 1024 * 1024 + scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) + width = round(w_ratio * scale / 8) * 8 + height = round(h_ratio * scale / 8) * 8 + return io.NodeOutput(width, height) + + +class ResolutionExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ResolutionSelector, + ] + + +async def comfy_entrypoint() -> ResolutionExtension: + return ResolutionExtension() diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py index d1feb031e..918ddc02b 100644 --- a/comfy_extras/nodes_rope.py +++ b/comfy_extras/nodes_rope.py @@ -12,14 +12,14 @@ class ScaleROPE(io.ComfyNode): is_experimental=True, inputs=[ io.Model.Input("model"), - io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1), - io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1), + io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True), + io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True), - io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1), - io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1), + io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True), + io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True), - io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1), - io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1), + io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1, advanced=True), + io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1, advanced=True), ], diff --git a/comfy_extras/nodes_sag.py b/comfy_extras/nodes_sag.py index 0f47db30b..d9c47851c 100644 --- a/comfy_extras/nodes_sag.py +++ b/comfy_extras/nodes_sag.py @@ -117,7 +117,7 @@ class SelfAttentionGuidance(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), - io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1), + io.Float.Input("blur_sigma", default=2.0, min=0.0, max=10.0, step=0.1, advanced=True), ], outputs=[ io.Model.Output(), diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 14782cb2b..c43844a1a 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -55,7 +55,7 @@ class EmptySD3LatentImage(io.ComfyNode): @classmethod def execute(cls, width, height, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples":latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) generate = execute # TODO: remove @@ -65,13 +65,14 @@ class CLIPTextEncodeSD3(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeSD3", + search_aliases=["sd3 prompt"], category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), io.String.Input("clip_l", multiline=True, dynamic_prompts=True), io.String.Input("clip_g", multiline=True, dynamic_prompts=True), io.String.Input("t5xxl", multiline=True, dynamic_prompts=True), - io.Combo.Input("empty_padding", options=["none", "empty_prompt"]), + io.Combo.Input("empty_padding", options=["none", "empty_prompt"], advanced=True), ], outputs=[ 
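+ # a single conditioning built from the clip_l, clip_g and t5xxl prompts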
io.Conditioning.Output(), @@ -178,10 +179,10 @@ class SkipLayerGuidanceSD3(io.ComfyNode): description="Generic version of SkipLayerGuidance node that can be used on every DiT model.", inputs=[ io.Model.Input("model"), - io.String.Input("layers", default="7, 8, 9", multiline=False), + io.String.Input("layers", default="7, 8, 9", multiline=False, advanced=True), io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), - io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), - io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001, advanced=True), ], outputs=[ io.Model.Output(), diff --git a/comfy_extras/nodes_sdpose.py b/comfy_extras/nodes_sdpose.py new file mode 100644 index 000000000..71441848e --- /dev/null +++ b/comfy_extras/nodes_sdpose.py @@ -0,0 +1,740 @@ +import torch +import comfy.utils +import numpy as np +import math +import colorsys +from tqdm import tqdm +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_extras.nodes_lotus import LotusConditioning + + +def _preprocess_keypoints(kp_raw, sc_raw): + """Insert neck keypoint and remap from MMPose to OpenPose ordering. + + Returns (kp, sc) where kp has shape (134, 2) and sc has shape (134,). + Layout: + 0-17 body (18 kp, OpenPose order) + 18-23 feet (6 kp) + 24-91 face (68 kp) + 92-112 right hand (21 kp) + 113-133 left hand (21 kp) + """ + kp = np.array(kp_raw, dtype=np.float32) + sc = np.array(sc_raw, dtype=np.float32) + if len(kp) >= 17: + neck = (kp[5] + kp[6]) / 2 + neck_score = min(sc[5], sc[6]) if sc[5] > 0.3 and sc[6] > 0.3 else 0 + kp = np.insert(kp, 17, neck, axis=0) + sc = np.insert(sc, 17, neck_score) + mmpose_idx = np.array([17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3]) + openpose_idx = np.array([ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17]) + tmp_kp, tmp_sc = kp.copy(), sc.copy() + tmp_kp[openpose_idx] = kp[mmpose_idx] + tmp_sc[openpose_idx] = sc[mmpose_idx] + kp, sc = tmp_kp, tmp_sc + return kp, sc + + +def _to_openpose_frames(all_keypoints, all_scores, height, width): + """Convert raw keypoint lists to a list of OpenPose-style frame dicts. + + Each frame dict contains: + canvas_width, canvas_height, people: list of person dicts with keys: + pose_keypoints_2d - 18 body kp as flat [x,y,score,...] (absolute pixels) + foot_keypoints_2d - 6 foot kp as flat [x,y,score,...] (absolute pixels) + face_keypoints_2d - 70 face kp as flat [x,y,score,...] 
(absolute pixels) + indices 0-67: 68 face landmarks + index 68: right eye (body[14]) + index 69: left eye (body[15]) + hand_right_keypoints_2d - 21 right-hand kp (absolute pixels) + hand_left_keypoints_2d - 21 left-hand kp (absolute pixels) + """ + def _flatten(kp_slice, sc_slice): + return np.stack([kp_slice[:, 0], kp_slice[:, 1], sc_slice], axis=1).flatten().tolist() + + frames = [] + for img_idx in range(len(all_keypoints)): + people = [] + for kp_raw, sc_raw in zip(all_keypoints[img_idx], all_scores[img_idx]): + kp, sc = _preprocess_keypoints(kp_raw, sc_raw) + # 70 face kp = 68 face landmarks + REye (body[14]) + LEye (body[15]) + face_kp = np.concatenate([kp[24:92], kp[[14, 15]]], axis=0) + face_sc = np.concatenate([sc[24:92], sc[[14, 15]]], axis=0) + people.append({ + "pose_keypoints_2d": _flatten(kp[0:18], sc[0:18]), + "foot_keypoints_2d": _flatten(kp[18:24], sc[18:24]), + "face_keypoints_2d": _flatten(face_kp, face_sc), + "hand_right_keypoints_2d": _flatten(kp[92:113], sc[92:113]), + "hand_left_keypoints_2d": _flatten(kp[113:134], sc[113:134]), + }) + frames.append({"canvas_width": width, "canvas_height": height, "people": people}) + return frames + + +class KeypointDraw: + """ + Pose keypoint drawing class that supports both numpy and cv2 backends. + """ + def __init__(self): + try: + import cv2 + self.draw = cv2 + except ImportError: + self.draw = self + + # Hand connections (same for both hands) + self.hand_edges = [ + [0, 1], [1, 2], [2, 3], [3, 4], # thumb + [0, 5], [5, 6], [6, 7], [7, 8], # index + [0, 9], [9, 10], [10, 11], [11, 12], # middle + [0, 13], [13, 14], [14, 15], [15, 16], # ring + [0, 17], [17, 18], [18, 19], [19, 20], # pinky + ] + + # Body connections - matching DWPose limbSeq (1-indexed, converted to 0-indexed) + self.body_limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], + [1, 16], [16, 18] + ] + + # Colors matching DWPose + self.colors = [ + [255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], + [85, 255, 0], [0, 255, 0], [0, 255, 85], [0, 255, 170], [0, 255, 255], + [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85] + ] + + @staticmethod + def circle(canvas_np, center, radius, color, **kwargs): + """Draw a filled circle using NumPy vectorized operations.""" + cx, cy = center + h, w = canvas_np.shape[:2] + + radius_int = int(np.ceil(radius)) + + y_min, y_max = max(0, cy - radius_int), min(h, cy + radius_int + 1) + x_min, x_max = max(0, cx - radius_int), min(w, cx + radius_int + 1) + + if y_max <= y_min or x_max <= x_min: + return + + y, x = np.ogrid[y_min:y_max, x_min:x_max] + mask = (x - cx)**2 + (y - cy)**2 <= radius**2 + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def line(canvas_np, pt1, pt2, color, thickness=1, **kwargs): + """Draw line using Bresenham's algorithm with NumPy operations.""" + x0, y0, x1, y1 = *pt1, *pt2 + h, w = canvas_np.shape[:2] + dx, dy = abs(x1 - x0), abs(y1 - y0) + sx, sy = (1 if x0 < x1 else -1), (1 if y0 < y1 else -1) + err, x, y, line_points = dx - dy, x0, y0, [] + + while True: + line_points.append((x, y)) + if x == x1 and y == y1: + break + e2 = 2 * err + if e2 > -dy: + err, x = err - dy, x + sx + if e2 < dx: + err, y = err + dx, y + sy + + if thickness > 1: + radius, radius_int = (thickness / 2.0) + 0.5, int(np.ceil((thickness / 2.0) + 0.5)) + for px, py in line_points: + y_min, y_max, x_min, x_max = max(0, py - radius_int), 
min(h, py + radius_int + 1), max(0, px - radius_int), min(w, px + radius_int + 1) + if y_max > y_min and x_max > x_min: + yy, xx = np.ogrid[y_min:y_max, x_min:x_max] + canvas_np[y_min:y_max, x_min:x_max][(xx - px)**2 + (yy - py)**2 <= radius**2] = color + else: + line_points = np.array(line_points) + valid = (line_points[:, 1] >= 0) & (line_points[:, 1] < h) & (line_points[:, 0] >= 0) & (line_points[:, 0] < w) + if (valid_points := line_points[valid]).size: + canvas_np[valid_points[:, 1], valid_points[:, 0]] = color + + @staticmethod + def fillConvexPoly(canvas_np, pts, color, **kwargs): + """Fill polygon using vectorized scanline algorithm.""" + if len(pts) < 3: + return + pts = np.array(pts, dtype=np.int32) + h, w = canvas_np.shape[:2] + y_min, y_max, x_min, x_max = max(0, pts[:, 1].min()), min(h, pts[:, 1].max() + 1), max(0, pts[:, 0].min()), min(w, pts[:, 0].max() + 1) + if y_max <= y_min or x_max <= x_min: + return + yy, xx = np.mgrid[y_min:y_max, x_min:x_max] + mask = np.zeros((y_max - y_min, x_max - x_min), dtype=bool) + + for i in range(len(pts)): + p1, p2 = pts[i], pts[(i + 1) % len(pts)] + y1, y2 = p1[1], p2[1] + if y1 == y2: + continue + if y1 > y2: + p1, p2, y1, y2 = p2, p1, p2[1], p1[1] + if not (edge_mask := (yy >= y1) & (yy < y2)).any(): + continue + mask ^= edge_mask & (xx >= p1[0] + (yy - y1) * (p2[0] - p1[0]) / (y2 - y1)) + + canvas_np[y_min:y_max, x_min:x_max][mask] = color + + @staticmethod + def ellipse2Poly(center, axes, angle, arc_start, arc_end, delta=1, **kwargs): + """Python implementation of cv2.ellipse2Poly.""" + axes = (axes[0] + 0.5, axes[1] + 0.5) # to better match cv2 output + angle = angle % 360 + if arc_start > arc_end: + arc_start, arc_end = arc_end, arc_start + while arc_start < 0: + arc_start, arc_end = arc_start + 360, arc_end + 360 + while arc_end > 360: + arc_end, arc_start = arc_end - 360, arc_start - 360 + if arc_end - arc_start > 360: + arc_start, arc_end = 0, 360 + + angle_rad = math.radians(angle) + alpha, beta = math.cos(angle_rad), math.sin(angle_rad) + pts = [] + for i in range(arc_start, arc_end + delta, delta): + theta_rad = math.radians(min(i, arc_end)) + x, y = axes[0] * math.cos(theta_rad), axes[1] * math.sin(theta_rad) + pts.append([int(round(center[0] + x * alpha - y * beta)), int(round(center[1] + x * beta + y * alpha))]) + + unique_pts, prev_pt = [], (float('inf'), float('inf')) + for pt in pts: + if (pt_tuple := tuple(pt)) != prev_pt: + unique_pts.append(pt) + prev_pt = pt_tuple + + return unique_pts if len(unique_pts) > 1 else [[center[0], center[1]], [center[0], center[1]]] + + def draw_wholebody_keypoints(self, canvas, keypoints, scores=None, threshold=0.3, + draw_body=True, draw_feet=True, draw_face=True, draw_hands=True, stick_width=4, face_point_size=3): + """ + Draw wholebody keypoints (134 keypoints after processing) in DWPose style. 
+ + Expected keypoint format (after neck insertion and remapping): + - Body: 0-17 (18 keypoints in OpenPose format, neck at index 1) + - Foot: 18-23 (6 keypoints) + - Face: 24-91 (68 landmarks) + - Right hand: 92-112 (21 keypoints) + - Left hand: 113-133 (21 keypoints) + + Args: + canvas: The canvas to draw on (numpy array) + keypoints: Array of keypoint coordinates + scores: Optional confidence scores for each keypoint + threshold: Minimum confidence threshold for drawing keypoints + + Returns: + canvas: The canvas with keypoints drawn + """ + H, W, C = canvas.shape + + # Draw body limbs + if draw_body and len(keypoints) >= 18: + for i, limb in enumerate(self.body_limbSeq): + # Convert from 1-indexed to 0-indexed + idx1, idx2 = limb[0] - 1, limb[1] - 1 + + if idx1 >= 18 or idx2 >= 18: + continue + + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + Y = [keypoints[idx1][0], keypoints[idx2][0]] + X = [keypoints[idx1][1], keypoints[idx2][1]] + mX, mY = (X[0] + X[1]) / 2, (Y[0] + Y[1]) / 2 + length = math.sqrt((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) + + if length < 1: + continue + + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + + polygon = self.draw.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stick_width), int(angle), 0, 360, 1) + + self.draw.fillConvexPoly(canvas, polygon, self.colors[i % len(self.colors)]) + + # Draw body keypoints + if draw_body and len(keypoints) >= 18: + for i in range(18): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw foot keypoints (18-23, 6 keypoints) + if draw_feet and len(keypoints) >= 24: + for i in range(18, 24): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, self.colors[i % len(self.colors)], thickness=-1) + + # Draw right hand (92-112) + if draw_hands and len(keypoints) >= 113: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 92 + edge[0], 92 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw right hand keypoints + for i in range(92, 113): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw left hand (113-133) + if draw_hands and len(keypoints) >= 134: + eps = 0.01 + for ie, edge in enumerate(self.hand_edges): + idx1, idx2 = 113 + edge[0], 113 + edge[1] + if scores is not None: + if scores[idx1] < threshold or scores[idx2] < threshold: + continue + + x1, y1 = int(keypoints[idx1][0]), int(keypoints[idx1][1]) + x2, y2 = int(keypoints[idx2][0]), int(keypoints[idx2][1]) + + if x1 > eps and y1 > eps and x2 > eps and 
y2 > eps: + if 0 <= x1 < W and 0 <= y1 < H and 0 <= x2 < W and 0 <= y2 < H: + # HSV to RGB conversion for rainbow colors + r, g, b = colorsys.hsv_to_rgb(ie / float(len(self.hand_edges)), 1.0, 1.0) + color = (int(r * 255), int(g * 255), int(b * 255)) + self.draw.line(canvas, (x1, y1), (x2, y2), color, thickness=2) + + # Draw left hand keypoints + for i in range(113, 134): + if scores is not None and i < len(scores) and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + + # Draw face keypoints (24-91) - white dots only, no lines + if draw_face and len(keypoints) >= 92: + eps = 0.01 + for i in range(24, 92): + if scores is not None and scores[i] < threshold: + continue + x, y = int(keypoints[i][0]), int(keypoints[i][1]) + if x > eps and y > eps and 0 <= x < W and 0 <= y < H: + self.draw.circle(canvas, (x, y), face_point_size, (255, 255, 255), thickness=-1) + + return canvas + +class SDPoseDrawKeypoints(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseDrawKeypoints", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "pose"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Boolean.Input("draw_body", default=True), + io.Boolean.Input("draw_hands", default=True), + io.Boolean.Input("draw_face", default=True), + io.Boolean.Input("draw_feet", default=False), + io.Int.Input("stick_width", default=4, min=1, max=10, step=1), + io.Int.Input("face_point_size", default=3, min=1, max=10, step=1), + io.Float.Input("score_threshold", default=0.3, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, keypoints, draw_body, draw_hands, draw_face, draw_feet, stick_width, face_point_size, score_threshold) -> io.NodeOutput: + if not keypoints: + return io.NodeOutput(torch.zeros((1, 64, 64, 3), dtype=torch.float32)) + height = keypoints[0]["canvas_height"] + width = keypoints[0]["canvas_width"] + + def _parse(flat, n): + arr = np.array(flat, dtype=np.float32).reshape(n, 3) + return arr[:, :2], arr[:, 2] + + def _zeros(n): + return np.zeros((n, 2), dtype=np.float32), np.zeros(n, dtype=np.float32) + + pose_outputs = [] + drawer = KeypointDraw() + + for frame in tqdm(keypoints, desc="Drawing keypoints on frames"): + canvas = np.zeros((height, width, 3), dtype=np.uint8) + for person in frame["people"]: + body_kp, body_sc = _parse(person["pose_keypoints_2d"], 18) + foot_raw = person.get("foot_keypoints_2d") + foot_kp, foot_sc = _parse(foot_raw, 6) if foot_raw else _zeros(6) + face_kp, face_sc = _parse(person["face_keypoints_2d"], 70) + face_kp, face_sc = face_kp[:68], face_sc[:68] # drop appended eye kp; body already draws them + rhand_kp, rhand_sc = _parse(person["hand_right_keypoints_2d"], 21) + lhand_kp, lhand_sc = _parse(person["hand_left_keypoints_2d"], 21) + + kp = np.concatenate([body_kp, foot_kp, face_kp, rhand_kp, lhand_kp], axis=0) + sc = np.concatenate([body_sc, foot_sc, face_sc, rhand_sc, lhand_sc], axis=0) + + canvas = drawer.draw_wholebody_keypoints( + canvas, kp, sc, + threshold=score_threshold, + draw_body=draw_body, draw_feet=draw_feet, + draw_face=draw_face, draw_hands=draw_hands, + stick_width=stick_width, face_point_size=face_point_size, + ) + pose_outputs.append(canvas) + + pose_outputs_np = np.stack(pose_outputs) if len(pose_outputs) > 1 else 
np.expand_dims(pose_outputs[0], 0) + final_pose_output = torch.from_numpy(pose_outputs_np).float() / 255.0 + return io.NodeOutput(final_pose_output) + +class SDPoseKeypointExtractor(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseKeypointExtractor", + category="image/preprocessors", + search_aliases=["openpose", "pose detection", "preprocessor", "keypoints", "sdpose"], + description="Extract pose keypoints from images using the SDPose model: https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints", + inputs=[ + io.Model.Input("model"), + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Int.Input("batch_size", default=16, min=1, max=10000, step=1), + io.BoundingBox.Input("bboxes", optional=True, force_input=True, tooltip="Optional bounding boxes for more accurate detections. Required for multi-person detection."), + ], + outputs=[ + io.Custom("POSE_KEYPOINT").Output("keypoints", tooltip="Keypoints in OpenPose frame format (canvas_width, canvas_height, people)"), + ], + ) + + @classmethod + def execute(cls, model, vae, image, batch_size, bboxes=None) -> io.NodeOutput: + + height, width = image.shape[-3], image.shape[-2] + context = LotusConditioning().execute().result[0] + + # Use output_block_patch to capture the last 640-channel feature + def output_patch(h, hsp, transformer_options): + nonlocal captured_feat + if h.shape[1] == 640: # Capture the features for wholebody + captured_feat = h.clone() + return h, hsp + + model_clone = model.clone() + model_clone.model_options["transformer_options"] = {"patches": {"output_block_patch": [output_patch]}} + + if not hasattr(model.model.diffusion_model, 'heatmap_head'): + raise ValueError("The provided model does not have a heatmap_head. Please use SDPose model from here https://huggingface.co/Comfy-Org/SDPose/tree/main/checkpoints.") + + head = model.model.diffusion_model.heatmap_head + total_images = image.shape[0] + captured_feat = None + + model_h = int(head.heatmap_size[0]) * 4 # e.g. 192 * 4 = 768 + model_w = int(head.heatmap_size[1]) * 4 # e.g. 256 * 4 = 1024 + + def _run_on_latent(latent_batch): + """Run one forward pass and return (keypoints_list, scores_list) for the batch.""" + nonlocal captured_feat + captured_feat = None + _ = comfy.sample.sample( + model_clone, + noise=torch.zeros_like(latent_batch), + steps=1, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=context, negative=context, + latent_image=latent_batch, disable_noise=True, disable_pbar=True, + ) + return head(captured_feat) # keypoints_batch, scores_batch + + # all_keypoints / all_scores are lists-of-lists: + # outer index = input image index + # inner index = detected person (one per bbox, or one for full-image) + all_keypoints = [] # shape: [n_images][n_persons] + all_scores = [] # shape: [n_images][n_persons] + pbar = comfy.utils.ProgressBar(total_images) + + if bboxes is not None: + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + bboxes = [None] * total_images + # --- bbox-crop mode: one forward pass per crop ------------------------- + for img_idx in tqdm(range(total_images), desc="Extracting keypoints from crops"): + img = image[img_idx:img_idx + 1] # (1, H, W, C) + # Broadcasting: if fewer bbox lists than images, repeat the last one. 
+ img_bboxes = bboxes[min(img_idx, len(bboxes) - 1)] if bboxes else None + + img_keypoints = [] + img_scores = [] + + if img_bboxes: + for bbox in img_bboxes: + x1 = max(0, int(bbox["x"])) + y1 = max(0, int(bbox["y"])) + x2 = min(width, int(bbox["x"] + bbox["width"])) + y2 = min(height, int(bbox["y"] + bbox["height"])) + + if x2 <= x1 or y2 <= y1: + continue + + crop_h_px, crop_w_px = y2 - y1, x2 - x1 + crop = img[:, y1:y2, x1:x2, :] # (1, crop_h, crop_w, C) + + # scale to fit inside (model_h, model_w) while preserving aspect ratio, then pad to exact model size. + scale = min(model_h / crop_h_px, model_w / crop_w_px) + scaled_h, scaled_w = int(round(crop_h_px * scale)), int(round(crop_w_px * scale)) + pad_top, pad_left = (model_h - scaled_h) // 2, (model_w - scaled_w) // 2 + + crop_chw = crop.permute(0, 3, 1, 2).float() # BHWC → BCHW + scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled") + padded = torch.zeros(1, scaled.shape[1], model_h, model_w, dtype=scaled.dtype, device=scaled.device) + padded[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled + crop_resized = padded.permute(0, 2, 3, 1) # BCHW → BHWC + + latent_crop = vae.encode(crop_resized) + kp_batch, sc_batch = _run_on_latent(latent_crop) + kp, sc = kp_batch[0], sc_batch[0] # (K, 2), coords in model pixel space + + # remove padding offset, undo scale, offset to full-image coordinates. + kp = kp.copy() if isinstance(kp, np.ndarray) else np.array(kp, dtype=np.float32) + kp[..., 0] = (kp[..., 0] - pad_left) / scale + x1 + kp[..., 1] = (kp[..., 1] - pad_top) / scale + y1 + + img_keypoints.append(kp) + img_scores.append(sc) + else: + # No bboxes for this image – run on the full image + latent_img = vae.encode(img) + kp_batch, sc_batch = _run_on_latent(latent_img) + img_keypoints.append(kp_batch[0]) + img_scores.append(sc_batch[0]) + + all_keypoints.append(img_keypoints) + all_scores.append(img_scores) + pbar.update(1) + + else: # full-image mode, batched + tqdm_pbar = tqdm(total=total_images, desc="Extracting keypoints") + for batch_start in range(0, total_images, batch_size): + batch_end = min(batch_start + batch_size, total_images) + latent_batch = vae.encode(image[batch_start:batch_end]) + + kp_batch, sc_batch = _run_on_latent(latent_batch) + + for kp, sc in zip(kp_batch, sc_batch): + all_keypoints.append([kp]) + all_scores.append([sc]) + tqdm_pbar.update(1) + + pbar.update(batch_end - batch_start) + + openpose_frames = _to_openpose_frames(all_keypoints, all_scores, height, width) + return io.NodeOutput(openpose_frames) + + +def get_face_bboxes(kp2ds, scale, image_shape): + h, w = image_shape + kp2ds_face = kp2ds.copy()[1:] * (w, h) + + min_x, min_y = np.min(kp2ds_face, axis=0) + max_x, max_y = np.max(kp2ds_face, axis=0) + + initial_width = max_x - min_x + initial_height = max_y - min_y + + if initial_width <= 0 or initial_height <= 0: + return [0, 0, 0, 0] + + initial_area = initial_width * initial_height + + expanded_area = initial_area * scale + + new_width = np.sqrt(expanded_area * (initial_width / initial_height)) + new_height = np.sqrt(expanded_area * (initial_height / initial_width)) + + delta_width = (new_width - initial_width) / 2 + delta_height = (new_height - initial_height) / 4 + + expanded_min_x = max(min_x - delta_width, 0) + expanded_max_x = min(max_x + delta_width, w) + expanded_min_y = max(min_y - 3 * delta_height, 0) + expanded_max_y = min(max_y + delta_height, h) + + return [int(expanded_min_x), int(expanded_max_x), int(expanded_min_y), 
int(expanded_max_y)] + +class SDPoseFaceBBoxes(io.ComfyNode): + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SDPoseFaceBBoxes", + category="image/preprocessors", + search_aliases=["face bbox", "face bounding box", "pose", "keypoints"], + inputs=[ + io.Custom("POSE_KEYPOINT").Input("keypoints"), + io.Float.Input("scale", default=1.5, min=1.0, max=10.0, step=0.1, tooltip="Multiplier for the bounding box area around each detected face."), + io.Boolean.Input("force_square", default=True, tooltip="Expand the shorter bbox axis so the crop region is always square."), + ], + outputs=[ + io.BoundingBox.Output("bboxes", tooltip="Face bounding boxes per frame, compatible with SDPoseKeypointExtractor bboxes input."), + ], + ) + + @classmethod + def execute(cls, keypoints, scale, force_square) -> io.NodeOutput: + all_bboxes = [] + for frame in keypoints: + h = frame["canvas_height"] + w = frame["canvas_width"] + frame_bboxes = [] + for person in frame["people"]: + face_flat = person.get("face_keypoints_2d", []) + if not face_flat: + continue + # Parse absolute-pixel face keypoints (70 kp: 68 landmarks + REye + LEye) + face_arr = np.array(face_flat, dtype=np.float32).reshape(-1, 3) + face_xy = face_arr[:, :2] # (70, 2) in absolute pixels + + kp_norm = face_xy / np.array([w, h], dtype=np.float32) + kp_padded = np.vstack([np.zeros((1, 2), dtype=np.float32), kp_norm]) # (71, 2) + + x1, x2, y1, y2 = get_face_bboxes(kp_padded, scale, (h, w)) + if x2 > x1 and y2 > y1: + if force_square: + bw, bh = x2 - x1, y2 - y1 + if bw != bh: + side = max(bw, bh) + cx, cy = (x1 + x2) // 2, (y1 + y2) // 2 + half = side // 2 + x1 = max(0, cx - half) + y1 = max(0, cy - half) + x2 = min(w, x1 + side) + y2 = min(h, y1 + side) + # Re-anchor if clamped + x1 = max(0, x2 - side) + y1 = max(0, y2 - side) + frame_bboxes.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}) + + all_bboxes.append(frame_bboxes) + + return io.NodeOutput(all_bboxes) + + +class CropByBBoxes(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CropByBBoxes", + category="image/preprocessors", + search_aliases=["crop", "face crop", "bbox crop", "pose", "bounding box"], + description="Crop and resize regions from the input image batch based on provided bounding boxes.", + inputs=[ + io.Image.Input("image"), + io.BoundingBox.Input("bboxes", force_input=True), + io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), + io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), + io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), + ], + outputs=[ + io.Image.Output(tooltip="All crops stacked into a single image batch."), + ], + ) + + @classmethod + def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: + total_frames = image.shape[0] + img_h = image.shape[1] + img_w = image.shape[2] + num_ch = image.shape[3] + + if not isinstance(bboxes, list): + bboxes = [[bboxes]] + elif len(bboxes) == 0: + return io.NodeOutput(image) + + crops = [] + + for frame_idx in range(total_frames): + frame_bboxes = bboxes[min(frame_idx, len(bboxes) - 1)] + if not frame_bboxes: + continue + + frame_chw = image[frame_idx].permute(2, 0, 1).unsqueeze(0) # BHWC → BCHW (1, C, H, W) + + # Union all bboxes for this frame into a single crop region + x1 = min(int(b["x"]) for b in frame_bboxes) + 
y1 = min(int(b["y"]) for b in frame_bboxes) + x2 = max(int(b["x"] + b["width"]) for b in frame_bboxes) + y2 = max(int(b["y"] + b["height"]) for b in frame_bboxes) + + if padding > 0: + x1 = max(0, x1 - padding) + y1 = max(0, y1 - padding) + x2 = min(img_w, x2 + padding) + y2 = min(img_h, y2 + padding) + + x1, x2 = max(0, x1), min(img_w, x2) + y1, y2 = max(0, y1), min(img_h, y2) + + # Fallback for empty/degenerate crops + if x2 <= x1 or y2 <= y1: + fallback_size = int(min(img_h, img_w) * 0.3) + fb_x1 = max(0, (img_w - fallback_size) // 2) + fb_y1 = max(0, int(img_h * 0.1)) + fb_x2 = min(img_w, fb_x1 + fallback_size) + fb_y2 = min(img_h, fb_y1 + fallback_size) + if fb_x2 <= fb_x1 or fb_y2 <= fb_y1: + crops.append(torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)) + continue + x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 + + crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) + resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled") + crops.append(resized) + + if not crops: + return io.NodeOutput(image) + + out_images = torch.cat(crops, dim=0).permute(0, 2, 3, 1) # (N, H, W, C) + return io.NodeOutput(out_images) + + +class SDPoseExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SDPoseKeypointExtractor, + SDPoseDrawKeypoints, + SDPoseFaceBBoxes, + CropByBBoxes, + ] + +async def comfy_entrypoint() -> SDPoseExtension: + return SDPoseExtension() diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index 31b373370..5877719d3 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -15,7 +15,7 @@ class SD_4XUpscale_Conditioning(io.ComfyNode): io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), io.Float.Input("scale_ratio", default=4.0, min=0.0, max=10.0, step=0.01), - io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001), + io.Float.Input("noise_augmentation", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True), ], outputs=[ io.Conditioning.Output(display_name="positive"), diff --git a/comfy_extras/nodes_slg.py b/comfy_extras/nodes_slg.py index f462faa8f..8cc1f551e 100644 --- a/comfy_extras/nodes_slg.py +++ b/comfy_extras/nodes_slg.py @@ -21,11 +21,11 @@ class SkipLayerGuidanceDiT(io.ComfyNode): is_experimental=True, inputs=[ io.Model.Input("model"), - io.String.Input("double_layers", default="7, 8, 9"), - io.String.Input("single_layers", default="7, 8, 9"), + io.String.Input("double_layers", default="7, 8, 9", advanced=True), + io.String.Input("single_layers", default="7, 8, 9", advanced=True), io.Float.Input("scale", default=3.0, min=0.0, max=10.0, step=0.1), - io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001), - io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001), + io.Float.Input("start_percent", default=0.01, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=0.15, min=0.0, max=1.0, step=0.001, advanced=True), io.Float.Input("rescaling_scale", default=0.0, min=0.0, max=10.0, step=0.01), ], outputs=[ @@ -101,10 +101,10 @@ class SkipLayerGuidanceDiTSimple(io.ComfyNode): is_experimental=True, inputs=[ io.Model.Input("model"), - io.String.Input("double_layers", default="7, 8, 9"), - io.String.Input("single_layers", default="7, 8, 9"), - io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001), - 
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001), + io.String.Input("double_layers", default="7, 8, 9", advanced=True), + io.String.Input("single_layers", default="7, 8, 9", advanced=True), + io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True), + io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True), ], outputs=[ io.Model.Output(), diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index c6d8a683d..829c837a1 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -75,8 +75,8 @@ class StableZero123_Conditioning_Batched(io.ComfyNode): io.Int.Input("batch_size", default=1, min=1, max=4096), io.Float.Input("elevation", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), io.Float.Input("azimuth", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), - io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False), - io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False) + io.Float.Input("elevation_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False, advanced=True), + io.Float.Input("azimuth_batch_increment", default=0.0, min=-180.0, max=180.0, step=0.1, round=False, advanced=True) ], outputs=[ io.Conditioning.Output(display_name="positive"), diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 04c0b366a..8c1aebca9 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -33,7 +33,7 @@ class StableCascade_EmptyLatentImage(io.ComfyNode): inputs=[ io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), - io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("compression", default=42, min=4, max=128, step=1, advanced=True), io.Int.Input("batch_size", default=1, min=1, max=4096), ], outputs=[ @@ -62,7 +62,7 @@ class StableCascade_StageC_VAEEncode(io.ComfyNode): inputs=[ io.Image.Input("image"), io.Vae.Input("vae"), - io.Int.Input("compression", default=42, min=4, max=128, step=1), + io.Int.Input("compression", default=42, min=4, max=128, step=1, advanced=True), ], outputs=[ io.Latent.Output(display_name="stage_c"), diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 571d89f62..b4e5f148a 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -11,6 +11,7 @@ class StringConcatenate(io.ComfyNode): node_id="StringConcatenate", display_name="Concatenate", category="utils/string", + search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), @@ -31,6 +32,7 @@ class StringSubstring(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringSubstring", + search_aliases=["extract text", "text portion"], display_name="Substring", category="utils/string", inputs=[ @@ -53,6 +55,7 @@ class StringLength(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringLength", + search_aliases=["character count", "text size"], display_name="Length", category="utils/string", inputs=[ @@ -73,6 +76,7 @@ class CaseConverter(io.ComfyNode): def define_schema(cls): return io.Schema( 
node_id="CaseConverter", + search_aliases=["text case", "uppercase", "lowercase", "capitalize"], display_name="Case Converter", category="utils/string", inputs=[ @@ -105,6 +109,7 @@ class StringTrim(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringTrim", + search_aliases=["clean whitespace", "remove whitespace"], display_name="Trim", category="utils/string", inputs=[ @@ -135,6 +140,7 @@ class StringReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringReplace", + search_aliases=["find and replace", "substitute", "swap text"], display_name="Replace", category="utils/string", inputs=[ @@ -157,12 +163,13 @@ class StringContains(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringContains", + search_aliases=["text includes", "string includes"], display_name="Contains", category="utils/string", inputs=[ io.String.Input("string", multiline=True), io.String.Input("substring", multiline=True), - io.Boolean.Input("case_sensitive", default=True), + io.Boolean.Input("case_sensitive", default=True, advanced=True), ], outputs=[ io.Boolean.Output(display_name="contains"), @@ -184,13 +191,14 @@ class StringCompare(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StringCompare", + search_aliases=["text match", "string equals", "starts with", "ends with"], display_name="Compare", category="utils/string", inputs=[ io.String.Input("string_a", multiline=True), io.String.Input("string_b", multiline=True), io.Combo.Input("mode", options=["Starts With", "Ends With", "Equal"]), - io.Boolean.Input("case_sensitive", default=True), + io.Boolean.Input("case_sensitive", default=True, advanced=True), ], outputs=[ io.Boolean.Output(), @@ -219,14 +227,15 @@ class RegexMatch(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexMatch", + search_aliases=["pattern match", "text contains", "string match"], display_name="Regex Match", category="utils/string", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), - io.Boolean.Input("case_insensitive", default=True), - io.Boolean.Input("multiline", default=False), - io.Boolean.Input("dotall", default=False), + io.Boolean.Input("case_insensitive", default=True, advanced=True), + io.Boolean.Input("multiline", default=False, advanced=True), + io.Boolean.Input("dotall", default=False, advanced=True), ], outputs=[ io.Boolean.Output(display_name="matches"), @@ -259,16 +268,17 @@ class RegexExtract(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexExtract", + search_aliases=["pattern extract", "text parser", "parse text"], display_name="Regex Extract", category="utils/string", inputs=[ io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), io.Combo.Input("mode", options=["First Match", "All Matches", "First Group", "All Groups"]), - io.Boolean.Input("case_insensitive", default=True), - io.Boolean.Input("multiline", default=False), - io.Boolean.Input("dotall", default=False), - io.Int.Input("group_index", default=1, min=0, max=100), + io.Boolean.Input("case_insensitive", default=True, advanced=True), + io.Boolean.Input("multiline", default=False, advanced=True), + io.Boolean.Input("dotall", default=False, advanced=True), + io.Int.Input("group_index", default=1, min=0, max=100, advanced=True), ], outputs=[ io.String.Output(), @@ -333,6 +343,7 @@ class RegexReplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RegexReplace", + search_aliases=["pattern replace", 
"find and replace", "substitution"], display_name="Regex Replace", category="utils/string", description="Find and replace text using regex patterns.", @@ -340,10 +351,10 @@ class RegexReplace(io.ComfyNode): io.String.Input("string", multiline=True), io.String.Input("regex_pattern", multiline=True), io.String.Input("replace", multiline=True), - io.Boolean.Input("case_insensitive", default=True, optional=True), - io.Boolean.Input("multiline", default=False, optional=True), - io.Boolean.Input("dotall", default=False, optional=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), - io.Int.Input("count", default=0, min=0, max=100, optional=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), + io.Boolean.Input("case_insensitive", default=True, optional=True, advanced=True), + io.Boolean.Input("multiline", default=False, optional=True, advanced=True), + io.Boolean.Input("dotall", default=False, optional=True, advanced=True, tooltip="When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."), + io.Int.Input("count", default=0, min=0, max=100, optional=True, advanced=True, tooltip="Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."), ], outputs=[ io.String.Output(), diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py new file mode 100644 index 000000000..14cff14a6 --- /dev/null +++ b/comfy_extras/nodes_textgen.py @@ -0,0 +1,176 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + +class TextGenerate(io.ComfyNode): + @classmethod + def define_schema(cls): + # Define dynamic combo options for sampling mode + sampling_options = [ + io.DynamicCombo.Option( + key="on", + inputs=[ + io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001), + io.Int.Input("top_k", default=64, min=0, max=1000), + io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01), + io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01), + io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01), + io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff), + ] + ), + io.DynamicCombo.Option( + key="off", + inputs=[] + ), + ] + + return io.Schema( + node_id="TextGenerate", + category="textgen/", + search_aliases=["LLM", "gemma"], + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), + io.Image.Input("image", optional=True), + io.Int.Input("max_length", default=256, min=1, max=2048), + io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), + ], + outputs=[ + io.String.Output(display_name="generated_text"), + ], + ) + + @classmethod + def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput: + + tokens = clip.tokenize(prompt, image=image, skip_template=False, min_length=1) + + # Get sampling parameters from dynamic combo + do_sample = sampling_mode.get("sampling_mode") == "on" + temperature = sampling_mode.get("temperature", 1.0) + top_k = sampling_mode.get("top_k", 50) + top_p = sampling_mode.get("top_p", 1.0) + min_p = sampling_mode.get("min_p", 0.0) + seed = 
sampling_mode.get("seed", None) + repetition_penalty = sampling_mode.get("repetition_penalty", 1.0) + + generated_ids = clip.generate( + tokens, + do_sample=do_sample, + max_length=max_length, + temperature=temperature, + top_k=top_k, + top_p=top_p, + min_p=min_p, + repetition_penalty=repetition_penalty, + seed=seed + ) + + generated_text = clip.decode(generated_ids, skip_special_tokens=True) + return io.NodeOutput(generated_text) + + +LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model. +#### Guidelines +- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio). + - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc. + - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters. +- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements. +- Maintain chronological flow: use temporal connectors ("as," "then," "while"). +- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present"). +- Speech (only when requested): + - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'"). + - Specify language if not English and accent if relevant. +- Style: Include visual style at the beginning: "Style: