diff --git a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
similarity index 66%
rename from .ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
rename to .ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
index cece0aeb2..94ad31942 100755
--- a/.ci/windows_amd_base_files/run_amd_gpu_disable_smart_memory.bat
+++ b/.ci/windows_amd_base_files/run_amd_gpu_enable_dynamic_vram.bat
@@ -1,2 +1,2 @@
-.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
pause
diff --git a/.github/workflows/openapi-lint.yml b/.github/workflows/openapi-lint.yml
new file mode 100644
index 000000000..be949de2a
--- /dev/null
+++ b/.github/workflows/openapi-lint.yml
@@ -0,0 +1,31 @@
+name: OpenAPI Lint
+
+on:
+ pull_request:
+ paths:
+ - 'openapi.yaml'
+ - '.spectral.yaml'
+ - '.github/workflows/openapi-lint.yml'
+
+permissions:
+ contents: read
+
+jobs:
+ spectral:
+ name: Run Spectral
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+
+ - name: Set up Node.js
+ uses: actions/setup-node@v4
+ with:
+ node-version: '20'
+
+ - name: Install Spectral
+ run: npm install -g @stoplight/spectral-cli@6
+
+ - name: Lint openapi.yaml
+ run: spectral lint openapi.yaml --ruleset .spectral.yaml --fail-severity=error
diff --git a/.github/workflows/stable-release.yml b/.github/workflows/stable-release.yml
index f501b7b31..bc64ed74d 100644
--- a/.github/workflows/stable-release.yml
+++ b/.github/workflows/stable-release.yml
@@ -145,6 +145,8 @@ jobs:
cp -r ComfyUI/.ci/windows_${{ inputs.rel_name }}_base_files/* ./
cp ../update_comfyui_and_python_dependencies.bat ./update/
+ echo 'local-portable' > ComfyUI/.comfy_environment
+
cd ..
"C:\Program Files\7-Zip\7z.exe" a -t7z -m0=lzma2 -mx=9 -mfb=128 -md=768m -ms=on -mf=BCJ2 ComfyUI_windows_portable.7z ComfyUI_windows_portable
diff --git a/.github/workflows/tag-dispatch-cloud.yml b/.github/workflows/tag-dispatch-cloud.yml
new file mode 100644
index 000000000..53a0e91d6
--- /dev/null
+++ b/.github/workflows/tag-dispatch-cloud.yml
@@ -0,0 +1,45 @@
+name: Tag Dispatch to Cloud
+
+on:
+ push:
+ tags:
+ - 'v*'
+
+jobs:
+ dispatch-cloud:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Send repository dispatch to cloud
+ env:
+ DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }}
+ RELEASE_TAG: ${{ github.ref_name }}
+ run: |
+ set -euo pipefail
+
+ if [ -z "${DISPATCH_TOKEN:-}" ]; then
+ echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set."
+ exit 1
+ fi
+
+ RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}"
+
+ PAYLOAD="$(jq -n \
+ --arg release_tag "$RELEASE_TAG" \
+ --arg release_url "$RELEASE_URL" \
+ '{
+ event_type: "comfyui_tag_pushed",
+ client_payload: {
+ release_tag: $release_tag,
+ release_url: $release_url
+ }
+ }')"
+
+ curl -fsSL \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+ https://api.github.com/repos/Comfy-Org/cloud/dispatches \
+ -d "$PAYLOAD"
+
+ echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud"
diff --git a/.gitignore b/.gitignore
index 0ab4ba75e..fc426eda4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,4 @@ web_custom_versions/
.DS_Store
filtered-openapi.yaml
uv.lock
+.comfy_environment
diff --git a/.spectral.yaml b/.spectral.yaml
new file mode 100644
index 000000000..4bb4a4a94
--- /dev/null
+++ b/.spectral.yaml
@@ -0,0 +1,91 @@
+extends:
+ - spectral:oas
+
+# Severity levels: error, warn, info, hint, off
+# Rules from the built-in "spectral:oas" ruleset are active by default.
+# Below we tune severity and add custom rules for our conventions.
+#
+# This ruleset mirrors Comfy-Org/cloud/.spectral.yaml so specs across the
+# organization are linted against a single consistent standard.
+
+rules:
+ # -----------------------------------------------------------------------
+ # Built-in rule severity overrides
+ # -----------------------------------------------------------------------
+ operation-operationId: error
+ operation-description: warn
+ operation-tag-defined: error
+ info-contact: off
+ info-description: warn
+ no-eval-in-markdown: error
+ no-$ref-siblings: error
+
+ # -----------------------------------------------------------------------
+ # Custom rules: naming conventions
+ # -----------------------------------------------------------------------
+
+ # Property names should be snake_case
+ property-name-snake-case:
+ description: Property names must be snake_case
+ severity: warn
+ given: "$.components.schemas.*.properties[*]~"
+ then:
+ function: pattern
+ functionOptions:
+ match: "^[a-z][a-z0-9]*(_[a-z0-9]+)*$"
+
+ # Operation IDs should be camelCase
+ operation-id-camel-case:
+ description: Operation IDs must be camelCase
+ severity: warn
+ given: "$.paths.*.*.operationId"
+ then:
+ function: pattern
+ functionOptions:
+ match: "^[a-z][a-zA-Z0-9]*$"
+
+ # -----------------------------------------------------------------------
+ # Custom rules: response conventions
+ # -----------------------------------------------------------------------
+
+ # Error responses (4xx, 5xx) should use a consistent shape
+ error-response-schema:
+ description: Error responses should reference a standard error schema
+ severity: hint
+ given: "$.paths.*.*.responses[?(@property >= '400' && @property < '600')].content['application/json'].schema"
+ then:
+ field: "$ref"
+ function: truthy
+
+ # All 2xx responses with JSON body should have a schema
+ response-schema-defined:
+ description: Success responses with JSON content should define a schema
+ severity: warn
+ given: "$.paths.*.*.responses[?(@property >= '200' && @property < '300')].content['application/json']"
+ then:
+ field: schema
+ function: truthy
+
+ # -----------------------------------------------------------------------
+ # Custom rules: best practices
+ # -----------------------------------------------------------------------
+
+ # Path parameters must have a description
+ path-param-description:
+ description: Path parameters should have a description
+ severity: warn
+ given:
+ - "$.paths.*.parameters[?(@.in == 'path')]"
+ - "$.paths.*.*.parameters[?(@.in == 'path')]"
+ then:
+ field: description
+ function: truthy
+
+ # Schemas should have a description
+ schema-description:
+ description: Component schemas should have a description
+ severity: hint
+ given: "$.components.schemas.*"
+ then:
+ field: description
+ function: truthy
diff --git a/CODEOWNERS b/CODEOWNERS
index 4d5448636..946dbf946 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,2 +1,2 @@
# Admins
-* @comfyanonymous @kosinkadink @guill
+* @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai
diff --git a/README.md b/README.md
index f05311421..0fd317d0a 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
# ComfyUI
-**The most powerful and modular visual AI engine and application.**
+**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
-
+
+
-ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
+ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more.
+- ComfyUI natively supports the latest open-source state-of-the-art models.
+- API nodes provide access to leading closed-source models such as Nano Banana, Seedance, Hunyuan3D, and more.
+- It is available on Windows, Linux, and macOS, either locally with our desktop application or on our cloud.
+- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
+- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
+ - Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- - Releases a new stable version (e.g., v0.7.0) roughly every week.
+ - Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
@@ -193,13 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
-#### Alternative Downloads:
+#### All Official Portable Downloads:
[Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
-[Experimental portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
+[Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
-[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
+[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (Supports Nvidia 20 series and newer GPUs).
+
+[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI?
diff --git a/app/frontend_management.py b/app/frontend_management.py
index f753ef0de..7108bd35a 100644
--- a/app/frontend_management.py
+++ b/app/frontend_management.py
@@ -27,7 +27,7 @@ def frontend_install_warning_message():
return f"""
{get_missing_requirements_message()}
-This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
+The ComfyUI frontend is shipped as a pip package, so it needs to be updated separately from the ComfyUI code.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:
diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py
index d9aab5b22..72e8ac2b1 100644
--- a/app/node_replace_manager.py
+++ b/app/node_replace_manager.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import logging
+
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
@@ -31,8 +33,22 @@ class NodeReplaceManager:
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
- """Register a node replacement mapping."""
- self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
+ """Register a node replacement mapping.
+
+ Idempotent: if a replacement with the same (old_node_id, new_node_id)
+ is already registered, the duplicate is ignored. This prevents stale
+ entries from accumulating when custom nodes are reloaded in the same
+ process (e.g. via ComfyUI-Manager).
+ """
+ existing = self._replacements.setdefault(node_replace.old_node_id, [])
+ for entry in existing:
+ if entry.new_node_id == node_replace.new_node_id:
+ logging.debug(
+ "Node replacement %s -> %s already registered, ignoring duplicate.",
+ node_replace.old_node_id, node_replace.new_node_id,
+ )
+ return
+ existing.append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
diff --git a/app/user_manager.py b/app/user_manager.py
index e18afb71b..0517b3344 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -28,8 +28,8 @@ def get_file_info(path: str, relative_to: str) -> FileInfo:
return {
"path": os.path.relpath(path, relative_to).replace(os.sep, '/'),
"size": os.path.getsize(path),
- "modified": os.path.getmtime(path),
- "created": os.path.getctime(path)
+ "modified": int(os.path.getmtime(path) * 1000),
+ "created": int(os.path.getctime(path) * 1000),
}
diff --git a/blueprints/Brightness and Contrast.json b/blueprints/Brightness and Contrast.json
index 90bfe999d..78fc52f29 100644
--- a/blueprints/Brightness and Contrast.json
+++ b/blueprints/Brightness and Contrast.json
@@ -431,9 +431,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adjusts image brightness and contrast using a real-time GPU fragment shader."
}
]
},
"extra": {}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Canny to Image (Z-Image-Turbo).json b/blueprints/Canny to Image (Z-Image-Turbo).json
index ff9717308..14deb64cc 100644
--- a/blueprints/Canny to Image (Z-Image-Turbo).json
+++ b/blueprints/Canny to Image (Z-Image-Turbo).json
@@ -162,7 +162,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Canny to Image (Z-Image-Turbo)",
+ "name": "Canny to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1553,7 +1553,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
- "category": "Image generation and editing/Canny to image"
+ "category": "Image generation and editing/Canny to image",
+ "description": "Generates an image from a Canny edge map using Z-Image-Turbo, with text conditioning."
}
]
},
@@ -1574,4 +1575,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Canny to Video (LTX 2.0).json b/blueprints/Canny to Video (LTX 2.0).json
index fae8321b9..a9682c8a4 100644
--- a/blueprints/Canny to Video (LTX 2.0).json
+++ b/blueprints/Canny to Video (LTX 2.0).json
@@ -192,7 +192,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Canny to Video (LTX 2.0)",
+ "name": "Canny to Video (LTX 2.0)",
"inputNode": {
"id": -10,
"bounding": [
@@ -3600,7 +3600,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Canny to video"
+ "category": "Video generation and editing/Canny to video",
+ "description": "Generates video from Canny edge maps using LTX-2, with optional synchronized audio."
}
]
},
@@ -3616,4 +3617,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Chromatic Aberration.json b/blueprints/Chromatic Aberration.json
index ae8037b1b..893fb1190 100644
--- a/blueprints/Chromatic Aberration.json
+++ b/blueprints/Chromatic Aberration.json
@@ -377,8 +377,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adds lens-style chromatic aberration (color fringing) using a real-time GPU fragment shader."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Color Adjustment.json b/blueprints/Color Adjustment.json
index 622bf28af..5abbf8baa 100644
--- a/blueprints/Color Adjustment.json
+++ b/blueprints/Color Adjustment.json
@@ -596,7 +596,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adjusts saturation, temperature, tint, and vibrance using a real-time GPU fragment shader."
}
]
}
diff --git a/blueprints/Color Balance.json b/blueprints/Color Balance.json
index 21d6319ed..d921eab37 100644
--- a/blueprints/Color Balance.json
+++ b/blueprints/Color Balance.json
@@ -1129,7 +1129,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Balances colors across shadows, midtones, and highlights using a real-time GPU fragment shader."
}
]
}
diff --git a/blueprints/Color Curves.json b/blueprints/Color Curves.json
index 1461cf396..b9bfb7029 100644
--- a/blueprints/Color Curves.json
+++ b/blueprints/Color Curves.json
@@ -608,7 +608,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Fine-tunes tone and color with per-channel curve adjustments using a real-time GPU fragment shader."
}
]
}
diff --git a/blueprints/Crop Images 2x2.json b/blueprints/Crop Images 2x2.json
index 2aa42cfc3..99b89b608 100644
--- a/blueprints/Crop Images 2x2.json
+++ b/blueprints/Crop Images 2x2.json
@@ -1609,7 +1609,8 @@
}
],
"extra": {},
- "category": "Image Tools/Crop"
+ "category": "Image Tools/Crop",
+ "description": "Splits an image into a 2×2 grid of four equal tiles."
}
]
},
diff --git a/blueprints/Crop Images 3x3.json b/blueprints/Crop Images 3x3.json
index 3a3615ac8..6ac636da4 100644
--- a/blueprints/Crop Images 3x3.json
+++ b/blueprints/Crop Images 3x3.json
@@ -2946,7 +2946,8 @@
}
],
"extra": {},
- "category": "Image Tools/Crop"
+ "category": "Image Tools/Crop",
+ "description": "Splits an image into a 3×3 grid of nine equal tiles."
}
]
},
diff --git a/blueprints/Depth to Image (Z-Image-Turbo).json b/blueprints/Depth to Image (Z-Image-Turbo).json
index 4f69a8149..fe9ef0f72 100644
--- a/blueprints/Depth to Image (Z-Image-Turbo).json
+++ b/blueprints/Depth to Image (Z-Image-Turbo).json
@@ -1579,7 +1579,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
- "category": "Image generation and editing/Depth to image"
+ "category": "Image generation and editing/Depth to image",
+ "description": "Generates an image from a depth map using Z-Image-Turbo with text conditioning."
},
{
"id": "458bdf3c-4b58-421c-af50-c9c663a4d74c",
@@ -2461,7 +2462,8 @@
]
},
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},
diff --git a/blueprints/Depth to Video (ltx 2.0).json b/blueprints/Depth to Video (ltx 2.0).json
index f15212520..bb28695a2 100644
--- a/blueprints/Depth to Video (ltx 2.0).json
+++ b/blueprints/Depth to Video (ltx 2.0).json
@@ -4233,7 +4233,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Depth to video"
+ "category": "Video generation and editing/Depth to video",
+ "description": "Generates video from depth maps using LTX-2, with optional synchronized audio."
},
{
"id": "38b60539-50a7-42f9-a5fe-bdeca26272e2",
@@ -5192,7 +5193,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},
diff --git a/blueprints/Edge-Preserving Blur.json b/blueprints/Edge-Preserving Blur.json
index 18012beb1..fbda9f126 100644
--- a/blueprints/Edge-Preserving Blur.json
+++ b/blueprints/Edge-Preserving Blur.json
@@ -450,9 +450,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Blur"
+ "category": "Image Tools/Blur",
+ "description": "Applies bilateral (edge-preserving) blur to soften images while retaining detail."
}
]
},
"extra": {}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Film Grain.json b/blueprints/Film Grain.json
index a680b3ece..3226ea9aa 100644
--- a/blueprints/Film Grain.json
+++ b/blueprints/Film Grain.json
@@ -580,8 +580,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adds procedural film grain texture for a cinematic look via GPU fragment shader."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/First-Last-Frame to Video (LTX-2.3).json b/blueprints/First-Last-Frame to Video (LTX-2.3).json
index 8ec9ed61a..f509aefe0 100644
--- a/blueprints/First-Last-Frame to Video (LTX-2.3).json
+++ b/blueprints/First-Last-Frame to Video (LTX-2.3).json
@@ -3350,7 +3350,8 @@
}
],
"extra": {},
- "category": "Video generation and editing/First-Last-Frame to Video"
+ "category": "Video generation and editing/First-Last-Frame to Video",
+ "description": "Generates a video interpolating between first and last keyframes using LTX-2.3."
}
]
},
diff --git a/blueprints/Glow.json b/blueprints/Glow.json
index 1dafb2d35..2bbfdee51 100644
--- a/blueprints/Glow.json
+++ b/blueprints/Glow.json
@@ -575,8 +575,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adds a glow/bloom effect around bright image areas via GPU fragment shader."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Hue and Saturation.json b/blueprints/Hue and Saturation.json
index 1a2df8937..cddf0154a 100644
--- a/blueprints/Hue and Saturation.json
+++ b/blueprints/Hue and Saturation.json
@@ -752,8 +752,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adjusts hue, saturation, and lightness of an image using a real-time GPU fragment shader."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Image Blur.json b/blueprints/Image Blur.json
index 3c7a784b0..0ca8d9931 100644
--- a/blueprints/Image Blur.json
+++ b/blueprints/Image Blur.json
@@ -374,7 +374,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Blur"
+ "category": "Image Tools/Blur",
+ "description": "Applies Gaussian, Box, or Radial blur to soften images and create stylized depth or motion effects."
}
]
}
diff --git a/blueprints/Image Captioning (gemini).json b/blueprints/Image Captioning (gemini).json
index 98cfb8999..2fc5d6746 100644
--- a/blueprints/Image Captioning (gemini).json
+++ b/blueprints/Image Captioning (gemini).json
@@ -310,7 +310,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Text generation/Image Captioning"
+ "category": "Text generation/Image Captioning",
+ "description": "Generates descriptive captions for images using Google's Gemini multimodal LLM."
}
]
}
diff --git a/blueprints/Image Channels.json b/blueprints/Image Channels.json
index 9c7b675b2..b6fdff5be 100644
--- a/blueprints/Image Channels.json
+++ b/blueprints/Image Channels.json
@@ -315,8 +315,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Manipulates individual RGBA channels for masking, compositing, and channel effects."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Image Edit (FireRed Image Edit 1.1).json b/blueprints/Image Edit (FireRed Image Edit 1.1).json
index c34246ce6..14310353c 100644
--- a/blueprints/Image Edit (FireRed Image Edit 1.1).json
+++ b/blueprints/Image Edit (FireRed Image Edit 1.1).json
@@ -2138,7 +2138,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Edit image"
+ "category": "Image generation and editing/Edit image",
+ "description": "Edits images via text instructions using FireRed Image Edit 1.1, a diffusion-based instruction-following editing model."
}
]
},
diff --git a/blueprints/Image Edit (Flux.2 Klein 4B).json b/blueprints/Image Edit (Flux.2 Klein 4B).json
index 6f2f7dc01..7f6fa7a4b 100644
--- a/blueprints/Image Edit (Flux.2 Klein 4B).json
+++ b/blueprints/Image Edit (Flux.2 Klein 4B).json
@@ -1472,7 +1472,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Edit image"
+ "category": "Image generation and editing/Edit image",
+ "description": "Edits an input image via text instructions using FLUX.2 [klein] 4B."
},
{
"id": "6007e698-2ebd-4917-84d8-299b35d7b7ab",
@@ -1821,7 +1822,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Applies reference image conditioning for style/identity transfer (Flux.2 Klein 4B)."
}
]
},
@@ -1837,4 +1839,4 @@
}
},
"version": 0.4
-}
\ No newline at end of file
+}
diff --git a/blueprints/Image Edit (LongCat Image Edit).json b/blueprints/Image Edit (LongCat Image Edit).json
index 5b4eb18f0..de1c155a2 100644
--- a/blueprints/Image Edit (LongCat Image Edit).json
+++ b/blueprints/Image Edit (LongCat Image Edit).json
@@ -1417,7 +1417,8 @@
}
],
"extra": {},
- "category": "Image generation and editing/Edit image"
+ "category": "Image generation and editing/Edit image",
+ "description": "Edits images via text instructions using LongCat Image Edit, an instruction-following image editing diffusion model."
}
]
},
diff --git a/blueprints/Image Edit (Qwen 2511).json b/blueprints/Image Edit (Qwen 2511).json
index 582171fa0..1aa7e5765 100644
--- a/blueprints/Image Edit (Qwen 2511).json
+++ b/blueprints/Image Edit (Qwen 2511).json
@@ -132,7 +132,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Image Edit (Qwen 2511)",
+ "name": "Image Edit (Qwen 2511)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1468,7 +1468,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
- "category": "Image generation and editing/Edit image"
+ "category": "Image generation and editing/Edit image",
+ "description": "Edits images via text instructions using Qwen-Image-Edit-2511 with improved character consistency and integrated LoRA."
}
]
},
@@ -1489,4 +1490,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Image Inpainting (Flux.1 Fill Dev).json b/blueprints/Image Inpainting (Flux.1 Fill Dev).json
index d40d63594..c1326ed3d 100644
--- a/blueprints/Image Inpainting (Flux.1 Fill Dev).json
+++ b/blueprints/Image Inpainting (Flux.1 Fill Dev).json
@@ -1188,7 +1188,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Inpaint image"
+ "category": "Image generation and editing/Inpaint image",
+ "description": "Inpaints masked image regions using Flux.1 fill [dev], Black Forest Labs' inpainting/outpainting model."
}
]
},
@@ -1202,4 +1203,4 @@
},
"ue_links": []
}
-}
\ No newline at end of file
+}
diff --git a/blueprints/Image Inpainting (Qwen-image).json b/blueprints/Image Inpainting (Qwen-image).json
index 95b2909fa..a06d57e19 100644
--- a/blueprints/Image Inpainting (Qwen-image).json
+++ b/blueprints/Image Inpainting (Qwen-image).json
@@ -1548,7 +1548,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Inpaint image"
+ "category": "Image generation and editing/Inpaint image",
+ "description": "Inpaints masked regions using Qwen-Image, extending its multilingual text rendering to inpainting tasks."
},
{
"id": "56a1f603-fbd2-40ed-94ef-c9ecbd96aca8",
@@ -1907,7 +1908,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Expands and softens mask edges to reduce visible seams after image processing."
}
]
},
diff --git a/blueprints/Image Levels.json b/blueprints/Image Levels.json
index ef256a1aa..1a1b18932 100644
--- a/blueprints/Image Levels.json
+++ b/blueprints/Image Levels.json
@@ -742,9 +742,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Color adjust"
+ "category": "Image Tools/Color adjust",
+ "description": "Adjusts black point, white point, and gamma for tonal range control via GPU shader."
}
]
},
"extra": {}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Image Outpainting (Qwen-Image).json b/blueprints/Image Outpainting (Qwen-Image).json
index 218fdc775..6c07227c0 100644
--- a/blueprints/Image Outpainting (Qwen-Image).json
+++ b/blueprints/Image Outpainting (Qwen-Image).json
@@ -1919,7 +1919,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Outpaint image"
+ "category": "Image generation and editing/Outpaint image",
+ "description": "Outpaints beyond image boundaries using Qwen-Image's outpainting capabilities."
},
{
"id": "f93c215e-c393-460e-9534-ed2c3d8a652e",
@@ -2278,7 +2279,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Expands and softens mask edges to reduce visible seams after image processing."
},
{
"id": "2a4b2cc0-db37-4302-a067-da392f38f06b",
@@ -2733,7 +2735,8 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Scales both image and mask together while preserving alignment for editing workflows."
}
]
},
diff --git a/blueprints/Image Upscale(Z-image-Turbo).json b/blueprints/Image Upscale(Z-image-Turbo).json
index 0d2b6e240..bd803a0b1 100644
--- a/blueprints/Image Upscale(Z-image-Turbo).json
+++ b/blueprints/Image Upscale(Z-image-Turbo).json
@@ -141,7 +141,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Image Upscale(Z-image-Turbo)",
+ "name": "Image Upscale (Z-image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1302,7 +1302,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Enhance"
+ "category": "Image generation and editing/Enhance",
+ "description": "Upscales images to higher resolution using Z-Image-Turbo."
}
]
},
diff --git a/blueprints/Image to Depth Map (Lotus).json b/blueprints/Image to Depth Map (Lotus).json
index 089f2cd42..12f10ba5b 100644
--- a/blueprints/Image to Depth Map (Lotus).json
+++ b/blueprints/Image to Depth Map (Lotus).json
@@ -99,7 +99,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Image to Depth Map (Lotus)",
+ "name": "Image to Depth Map (Lotus)",
"inputNode": {
"id": -10,
"bounding": [
@@ -948,7 +948,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Depth to image"
+ "category": "Image generation and editing/Depth to image",
+ "description": "Estimates a monocular depth map from an input image using the Lotus depth estimation model."
}
]
},
@@ -964,4 +965,4 @@
"workflowRendererVersion": "LG"
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Image to Layers(Qwen-Image-Layered).json b/blueprints/Image to Layers(Qwen-Image-Layered).json
index 8a525e7a5..7b44f0563 100644
--- a/blueprints/Image to Layers(Qwen-Image-Layered).json
+++ b/blueprints/Image to Layers(Qwen-Image-Layered).json
@@ -1586,7 +1586,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Image to layers"
+ "category": "Image generation and editing/Image to layers",
+ "description": "Decomposes an image into variable-resolution RGBA layers for independent editing using Qwen-Image-Layered."
}
]
},
diff --git a/blueprints/Image to Model (Hunyuan3d 2.1).json b/blueprints/Image to Model (Hunyuan3d 2.1).json
index 4705603a8..ee5552656 100644
--- a/blueprints/Image to Model (Hunyuan3d 2.1).json
+++ b/blueprints/Image to Model (Hunyuan3d 2.1).json
@@ -72,7 +72,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Image to Model (Hunyuan3d 2.1)",
+ "name": "Image to 3D Model (Hunyuan3d 2.1)",
"inputNode": {
"id": -10,
"bounding": [
@@ -765,7 +765,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "3D/Image to 3D Model"
+ "category": "3D/Image to 3D Model",
+ "description": "Generates 3D mesh models from a single input image using Hunyuan3D 2.0/2.1."
}
]
},
diff --git a/blueprints/Image to Video (LTX-2.3).json b/blueprints/Image to Video (LTX-2.3).json
index 86a601130..3db524ea0 100644
--- a/blueprints/Image to Video (LTX-2.3).json
+++ b/blueprints/Image to Video (LTX-2.3).json
@@ -4223,7 +4223,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
- "category": "Video generation and editing/Image to video"
+ "category": "Video generation and editing/Image to video",
+ "description": "Generates video from a single input image using LTX-2.3."
}
]
},
diff --git a/blueprints/Image to Video (Wan 2.2).json b/blueprints/Image to Video (Wan 2.2).json
index a8dafd3c9..3510aad18 100644
--- a/blueprints/Image to Video (Wan 2.2).json
+++ b/blueprints/Image to Video (Wan 2.2).json
@@ -206,7 +206,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Image to Video (Wan 2.2)",
+ "name": "Image to Video (Wan 2.2)",
"inputNode": {
"id": -10,
"bounding": [
@@ -2027,7 +2027,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Image to video"
+ "category": "Video generation and editing/Image to video",
+ "description": "Generates video from an image and text prompt using Wan 2.2, supporting T2V and I2V."
}
]
},
diff --git a/blueprints/Pose to Image (Z-Image-Turbo).json b/blueprints/Pose to Image (Z-Image-Turbo).json
index a55410ba4..5c2749efe 100644
--- a/blueprints/Pose to Image (Z-Image-Turbo).json
+++ b/blueprints/Pose to Image (Z-Image-Turbo).json
@@ -134,7 +134,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Pose to Image (Z-Image-Turbo)",
+ "name": "Pose to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1298,7 +1298,8 @@
"VHS_MetadataImage": true,
"VHS_KeepIntermediate": true
},
- "category": "Image generation and editing/Pose to image"
+ "category": "Image generation and editing/Pose to image",
+ "description": "Generates an image from pose keypoints using Z-Image-Turbo with text conditioning."
}
]
},
@@ -1319,4 +1320,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Pose to Video (LTX 2.0).json b/blueprints/Pose to Video (LTX 2.0).json
index 580900bc0..1ce49351a 100644
--- a/blueprints/Pose to Video (LTX 2.0).json
+++ b/blueprints/Pose to Video (LTX 2.0).json
@@ -3870,7 +3870,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Pose to video"
+ "category": "Video generation and editing/Pose to video",
+ "description": "Generates video from pose reference frames using LTX-2, with optional synchronized audio."
}
]
},
diff --git a/blueprints/Prompt Enhance.json b/blueprints/Prompt Enhance.json
index 5e57548ff..e260b1203 100644
--- a/blueprints/Prompt Enhance.json
+++ b/blueprints/Prompt Enhance.json
@@ -270,9 +270,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Text generation/Prompt enhance"
+ "category": "Text generation/Prompt enhance",
+ "description": "Expands short text prompts into detailed descriptions using a text generation model for better generation quality."
}
]
},
"extra": {}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Sharpen.json b/blueprints/Sharpen.json
index f332400fd..3c4099c6b 100644
--- a/blueprints/Sharpen.json
+++ b/blueprints/Sharpen.json
@@ -302,8 +302,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Sharpen"
+ "category": "Image Tools/Sharpen",
+ "description": "Sharpens image details using a GPU fragment shader for enhanced clarity."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Text to Audio (ACE-Step 1.5).json b/blueprints/Text to Audio (ACE-Step 1.5).json
index 206cf16be..5b8b8626f 100644
--- a/blueprints/Text to Audio (ACE-Step 1.5).json
+++ b/blueprints/Text to Audio (ACE-Step 1.5).json
@@ -222,7 +222,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Text to Audio (ACE-Step 1.5)",
+ "name": "Text to Audio (ACE-Step 1.5)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1502,7 +1502,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Audio/Music generation"
+ "category": "Audio/Music generation",
+ "description": "Generates audio/music from text prompts using ACE-Step 1.5, a diffusion-based audio generation model."
}
]
},
@@ -1518,4 +1519,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Text to Image (Flux.1 Dev).json b/blueprints/Text to Image (Flux.1 Dev).json
index 04c3cb95a..45f68f508 100644
--- a/blueprints/Text to Image (Flux.1 Dev).json
+++ b/blueprints/Text to Image (Flux.1 Dev).json
@@ -1029,7 +1029,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using Flux.1 [dev], Black Forest Labs' 12B diffusion model."
}
]
},
@@ -1043,4 +1044,4 @@
},
"ue_links": []
}
-}
\ No newline at end of file
+}
diff --git a/blueprints/Text to Image (Flux.1 Krea Dev).json b/blueprints/Text to Image (Flux.1 Krea Dev).json
index fe4db1cfc..30a78dca1 100644
--- a/blueprints/Text to Image (Flux.1 Krea Dev).json
+++ b/blueprints/Text to Image (Flux.1 Krea Dev).json
@@ -1023,7 +1023,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using Flux.1 Krea Dev, a Black Forest Labs × Krea collaboration variant."
}
]
},
@@ -1037,4 +1038,4 @@
},
"ue_links": []
}
-}
\ No newline at end of file
+}
diff --git a/blueprints/Text to Image (NetaYume Lumina).json b/blueprints/Text to Image (NetaYume Lumina).json
index 394ad1608..9e11b7a86 100644
--- a/blueprints/Text to Image (NetaYume Lumina).json
+++ b/blueprints/Text to Image (NetaYume Lumina).json
@@ -1104,7 +1104,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using NetaYume Lumina, fine-tuned from Neta Lumina for anime-style and illustration generation."
},
{
"id": "a07fdf06-1bda-4dac-bdbd-63ee8ebca1c9",
@@ -1458,11 +1459,12 @@
],
"extra": {
"workflowRendererVersion": "LG"
- }
+ },
+ "description": "Encodes a negative text prompt via CLIP for classifier-free guidance in anime-style generation (NetaYume Lumina)."
}
]
},
"extra": {
"ue_links": []
}
-}
\ No newline at end of file
+}
diff --git a/blueprints/Text to Image (Qwen-Image 2512).json b/blueprints/Text to Image (Qwen-Image 2512).json
index f52ea2ef2..09612be8b 100644
--- a/blueprints/Text to Image (Qwen-Image 2512).json
+++ b/blueprints/Text to Image (Qwen-Image 2512).json
@@ -1941,7 +1941,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using Qwen-Image-2512, with enhanced human realism and finer natural detail over the base version."
}
]
},
diff --git a/blueprints/Text to Image (Qwen-Image).json b/blueprints/Text to Image (Qwen-Image).json
index 70b4b44b3..e78d5a962 100644
--- a/blueprints/Text to Image (Qwen-Image).json
+++ b/blueprints/Text to Image (Qwen-Image).json
@@ -1873,7 +1873,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using Qwen-Image, Alibaba's 20B MMDiT model with excellent multilingual text rendering."
}
]
},
diff --git a/blueprints/Text to Image (Z-Image-Turbo).json b/blueprints/Text to Image (Z-Image-Turbo).json
index 6aa80e327..6975151ea 100644
--- a/blueprints/Text to Image (Z-Image-Turbo).json
+++ b/blueprints/Text to Image (Z-Image-Turbo).json
@@ -149,7 +149,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Text to Image (Z-Image-Turbo)",
+ "name": "Text to Image (Z-Image-Turbo)",
"inputNode": {
"id": -10,
"bounding": [
@@ -1054,7 +1054,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image generation and editing/Text to image"
+ "category": "Image generation and editing/Text to image",
+ "description": "Generates images from text prompts using Z-Image-Turbo, Alibaba's distilled 6B DiT model."
}
]
},
@@ -1075,4 +1076,4 @@
}
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Text to Video (LTX-2.3).json b/blueprints/Text to Video (LTX-2.3).json
index ff9bc6ccf..f44a216dd 100644
--- a/blueprints/Text to Video (LTX-2.3).json
+++ b/blueprints/Text to Video (LTX-2.3).json
@@ -4286,7 +4286,8 @@
"extra": {
"workflowRendererVersion": "Vue-corrected"
},
- "category": "Video generation and editing/Text to video"
+ "category": "Video generation and editing/Text to video",
+ "description": "Generates video from text prompts using LTX-2.3, Lightricks' video diffusion model."
}
]
},
diff --git a/blueprints/Text to Video (Wan 2.2).json b/blueprints/Text to Video (Wan 2.2).json
index 0ce485b67..a264a490d 100644
--- a/blueprints/Text to Video (Wan 2.2).json
+++ b/blueprints/Text to Video (Wan 2.2).json
@@ -1572,7 +1572,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Text to video"
+ "category": "Video generation and editing/Text to video",
+ "description": "Generates video from text prompts using Wan2.2, Alibaba's diffusion video model."
}
]
},
@@ -1586,4 +1587,4 @@
"VHS_KeepIntermediate": true
},
"version": 0.4
-}
+}
\ No newline at end of file
diff --git a/blueprints/Unsharp Mask.json b/blueprints/Unsharp Mask.json
index 137acaa43..79a4c954f 100644
--- a/blueprints/Unsharp Mask.json
+++ b/blueprints/Unsharp Mask.json
@@ -434,8 +434,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Image Tools/Sharpen"
+ "category": "Image Tools/Sharpen",
+ "description": "Enhances edge contrast via unsharp masking for a sharper image appearance."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Video Captioning (Gemini).json b/blueprints/Video Captioning (Gemini).json
index ea6dc8bee..7642b23c1 100644
--- a/blueprints/Video Captioning (Gemini).json
+++ b/blueprints/Video Captioning (Gemini).json
@@ -307,7 +307,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Text generation/Video Captioning"
+ "category": "Text generation/Video Captioning",
+ "description": "Generates descriptive captions for video input using Google's Gemini multimodal LLM."
}
]
}
diff --git a/blueprints/Video Inpaint(Wan2.1 VACE).json b/blueprints/Video Inpaint(Wan2.1 VACE).json
index f404e6773..a658be5f8 100644
--- a/blueprints/Video Inpaint(Wan2.1 VACE).json
+++ b/blueprints/Video Inpaint(Wan2.1 VACE).json
@@ -165,7 +165,7 @@
},
"revision": 0,
"config": {},
- "name": "local-Video Inpaint(Wan2.1 VACE)",
+ "name": "Video Inpaint (Wan 2.1 VACE)",
"inputNode": {
"id": -10,
"bounding": [
@@ -2368,7 +2368,8 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Inpaint video"
+ "category": "Video generation and editing/Inpaint video",
+ "description": "Inpaints masked regions in video frames using Wan 2.1 VACE."
}
]
},
diff --git a/blueprints/Video Stitch.json b/blueprints/Video Stitch.json
index 020896d78..6eb0f0bbf 100644
--- a/blueprints/Video Stitch.json
+++ b/blueprints/Video Stitch.json
@@ -584,8 +584,9 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video Tools/Stitch videos"
+ "category": "Video Tools/Stitch videos",
+ "description": "Stitches multiple video clips into a single sequential video file."
}
]
}
-}
+}
\ No newline at end of file
diff --git a/blueprints/Video Upscale(GAN x4).json b/blueprints/Video Upscale(GAN x4).json
index b61dc88d7..73476e36b 100644
--- a/blueprints/Video Upscale(GAN x4).json
+++ b/blueprints/Video Upscale(GAN x4).json
@@ -412,9 +412,10 @@
"extra": {
"workflowRendererVersion": "LG"
},
- "category": "Video generation and editing/Enhance video"
+ "category": "Video generation and editing/Enhance video",
+ "description": "Upscales video to 4× resolution using a GAN-based upscaling model."
}
]
},
"extra": {}
-}
+}
\ No newline at end of file
diff --git a/comfy/background_removal/birefnet.json b/comfy/background_removal/birefnet.json
new file mode 100644
index 000000000..f0960af39
--- /dev/null
+++ b/comfy/background_removal/birefnet.json
@@ -0,0 +1,7 @@
+{
+ "model_type": "birefnet",
+ "image_std": [1.0, 1.0, 1.0],
+ "image_mean": [0.0, 0.0, 0.0],
+ "image_size": 1024,
+ "resize_to_original": true
+}
diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py
new file mode 100644
index 000000000..df54b2b90
--- /dev/null
+++ b/comfy/background_removal/birefnet.py
@@ -0,0 +1,689 @@
+import torch
+import comfy.ops
+import numpy as np
+import torch.nn as nn
+from functools import partial
+import torch.nn.functional as F
+from torchvision.ops import deform_conv2d
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
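+# The slicing below evaluates to [384, 768, 1536]: backbone stage channel widths reused as extra decoder context.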
+CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
+ super().__init__()
+
+ self.dim = dim
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
+ self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
+ self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ k, v = kv[0], kv[1]
+
+ x = optimized_attention(
+ q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
+ ).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+
+ return x
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
+ self.act = nn.GELU()
+ self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ return x
+
+
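+# Split a (B, H, W, C) feature map into non-overlapping windows of shape (num_windows*B, window_size, window_size, C).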
+def window_partition(x, window_size):
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
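+# Inverse of window_partition: merge windows back into a (B, H, W, C) feature map.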
+def window_reverse(windows, window_size, H, W):
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
+
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
+ self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+
+ self.norm1 = norm_layer(dim, device=device, dtype=dtype)
+ self.attn = WindowAttention(
+ dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
+
+ self.norm2 = norm_layer(dim, device=device, dtype=dtype)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ B, L, C = x.shape
+ H, W = self.H, self.W
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
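+        # Cyclic shift for Swin shifted-window attention; the precomputed mask keeps tokens
+        # from different pre-shift windows from attending to each other.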
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ x_windows = window_partition(shifted_x, self.window_size)
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+
+ attn_windows = self.attn(x_windows, mask=attn_mask)
+
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ x = shortcut + x
+ x = x + self.mlp(self.norm2(x))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ def __init__(self, dim, device=None, dtype=None, operations=None):
+ super().__init__()
+ self.dim = dim
+ self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
+ self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
+
+ def forward(self, x, H, W):
+ B, L, C = x.shape
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ norm_layer=norm_layer,
+ device=device, dtype=dtype, operations=operations)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
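+        # Build the attention mask that isolates regions mixed together by the cyclic shift
+        # (standard Swin shifted-window masking).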
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size)
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
+ super().__init__()
+ patch_size = (patch_size, patch_size)
+ self.patch_size = patch_size
+
+ self.in_channels = in_channels
+ self.embed_dim = embed_dim
+
+ self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+class SwinTransformer(nn.Module):
+ def __init__(self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_channels=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+
+ norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
+ device=device, dtype=dtype, operations=operations,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ device=device, dtype=dtype, operations=operations)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+
+ def forward(self, x):
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+
+ outs = []
+ x = x.flatten(2).transpose(1, 2)
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
+
+class DeformableConv2d(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False, device=None, dtype=None, operations=None):
+
+ super(DeformableConv2d, self).__init__()
+
+ kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
+ self.stride = stride if type(stride) is tuple else (stride, stride)
+ self.padding = padding
+
+ self.offset_conv = operations.Conv2d(in_channels,
+ 2 * kernel_size[0] * kernel_size[1],
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ bias=True, device=device, dtype=dtype)
+
+ self.modulator_conv = operations.Conv2d(in_channels,
+ 1 * kernel_size[0] * kernel_size[1],
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ bias=True, device=device, dtype=dtype)
+
+ self.regular_conv = operations.Conv2d(in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=self.padding,
+ bias=bias, device=device, dtype=dtype)
+
+ def forward(self, x):
+ offset = self.offset_conv(x)
+ modulator = 2. * torch.sigmoid(self.modulator_conv(x))
+ weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
+
+ x = deform_conv2d(
+ input=x,
+ offset=offset,
+ weight=weight,
+ bias=None,
+ padding=self.padding,
+ mask=modulator,
+ stride=self.stride,
+ )
+ comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
+ return x
+
+class BasicDecBlk(nn.Module):
+ def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
+ super(BasicDecBlk, self).__init__()
+ inter_channels = 64
+ self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
+ self.relu_in = nn.ReLU(inplace=True)
+ self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
+ self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
+ self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
+ self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ x = self.bn_in(x)
+ x = self.relu_in(x)
+ x = self.dec_att(x)
+ x = self.conv_out(x)
+ x = self.bn_out(x)
+ return x
+
+
+class BasicLatBlk(nn.Module):
+ def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
+ super(BasicLatBlk, self).__init__()
+ self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.conv(x)
+ return x
+
+
+class _ASPPModuleDeformable(nn.Module):
+ def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
+ super(_ASPPModuleDeformable, self).__init__()
+ self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
+ stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
+ self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.atrous_conv(x)
+ x = self.bn(x)
+
+ return self.relu(x)
+
+
+class ASPPDeformable(nn.Module):
+ def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
+ super(ASPPDeformable, self).__init__()
+ self.down_scale = 1
+ if out_channels is None:
+ out_channels = in_channels
+ self.in_channelster = 256 // self.down_scale
+
+ self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
+ self.aspp_deforms = nn.ModuleList([
+ _ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
+ for conv_size in parallel_block_sizes
+ ])
+
+ self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
+ operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
+ operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
+ nn.ReLU(inplace=True))
+ self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
+ self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x1 = self.aspp1(x)
+ x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
+ x5 = self.global_avg_pool(x)
+ x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
+ x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ return x
+
+class BiRefNet(nn.Module):
+ def __init__(self, config=None, dtype=None, device=None, operations=None):
+ super(BiRefNet, self).__init__()
+ self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
+
+ channels = [1536, 768, 384, 192]
+ channels = [c * 2 for c in channels]
+ self.cxt = channels[1:][::-1][-3:]
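+        # channel widths of the three shallower encoder stages (after the dual-scale concat);
+        # their feature maps are resized and stacked onto x4 as extra context in forward_enc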
+ self.squeeze_module = nn.Sequential(*[
+ BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
+ for _ in range(1)
+ ])
+
+ self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
+
+ def forward_enc(self, x):
+ x1, x2, x3, x4 = self.bb(x)
+ B, C, H, W = x.shape
+ x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
+ x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
+ x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
+ x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
+ x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
+ x4 = torch.cat(
+ (
+ *[
+ F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
+ F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
+ F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
+                ][-len(self.cxt):],
+ x4
+ ),
+ dim=1
+ )
+ return (x1, x2, x3, x4)
+
+ def forward_ori(self, x):
+ (x1, x2, x3, x4) = self.forward_enc(x)
+ x4 = self.squeeze_module(x4)
+ features = [x, x1, x2, x3, x4]
+ scaled_preds = self.decoder(features)
+ return scaled_preds
+
+ def forward(self, pixel_values, intermediate_output=None):
+ scaled_preds = self.forward_ori(pixel_values)
+ return scaled_preds
+
+
+class Decoder(nn.Module):
+ def __init__(self, channels, device, dtype, operations):
+ super(Decoder, self).__init__()
+ # factory kwargs
+ fk = {"device":device, "dtype":dtype, "operations":operations}
+ DecoderBlock = partial(BasicDecBlk, **fk)
+ LateralBlock = partial(BasicLatBlk, **fk)
+ DBlock = partial(SimpleConvs, **fk)
+
+ self.split = True
+ N_dec_ipt = 64
+ ic = 64
+ ipt_cha_opt = 1
+ self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
+ self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
+ self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
+ self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
+ self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
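+        # Input widths above: with self.split the source image is tiled into the grid of the
+        # target feature map (see get_patches_batch) and the tiles are stacked along channels,
+        # so each block sees (number of tiles) * 3 channels: 2**10*3 = 32x32 tiles at 1/32 scale,
+        # 2**8*3 = 16x16 tiles at 1/16, down to 2**0*3 = the plain 3-channel image at full resolution.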
+
+ self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
+ self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
+ self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
+ self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
+
+ fk = {"device":device, "dtype":dtype}
+
+ self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
+
+ self.lateral_block4 = LateralBlock(channels[1], channels[1])
+ self.lateral_block3 = LateralBlock(channels[2], channels[2])
+ self.lateral_block2 = LateralBlock(channels[3], channels[3])
+
+ self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
+ self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
+ self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
+
+ _N = 16
+
+ self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
+ self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
+ self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
+
+ [setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
+ [setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
+
+ def get_patches_batch(self, x, p):
+ _size_h, _size_w = p.shape[2:]
+ patches_batch = []
+ for idx in range(x.shape[0]):
+ columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
+ patches_x = []
+ for column_x in columns_x:
+ patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
+ patch_sample = torch.cat(patches_x, dim=1)
+ patches_batch.append(patch_sample)
+ return torch.cat(patches_batch, dim=0)
+
+ def forward(self, features):
+ x, x1, x2, x3, x4 = features
+
+ patches_batch = self.get_patches_batch(x, x4) if self.split else x
+ x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
+ p4 = self.decoder_block4(x4)
+ p4_gdt = self.gdt_convs_4(p4)
+ gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
+ p4 = p4 * gdt_attn_4
+ _p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
+ _p3 = _p4 + self.lateral_block4(x3)
+
+ patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+ _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
+ p3 = self.decoder_block3(_p3)
+
+ p3_gdt = self.gdt_convs_3(p3)
+ gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
+ p3 = p3 * gdt_attn_3
+ _p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
+ _p2 = _p3 + self.lateral_block3(x2)
+
+ patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+ _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
+ p2 = self.decoder_block2(_p2)
+
+ p2_gdt = self.gdt_convs_2(p2)
+ gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
+ p2 = p2 * gdt_attn_2
+
+ _p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
+ _p1 = _p2 + self.lateral_block2(x1)
+
+ patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+ _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
+ _p1 = self.decoder_block1(_p1)
+ _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
+
+ patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+ _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
+ p1_out = self.conv_out1(_p1)
+ return p1_out
+
+
+class SimpleConvs(nn.Module):
+ def __init__(
+ self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
+ ) -> None:
+ super().__init__()
+ self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
+ self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
+
+ def forward(self, x):
+ return self.conv_out(self.conv1(x))
diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py
new file mode 100644
index 000000000..7877afd7f
--- /dev/null
+++ b/comfy/bg_removal_model.py
@@ -0,0 +1,78 @@
+from .utils import load_torch_file
+import os
+import json
+import torch
+import logging
+
+import comfy.ops
+import comfy.model_patcher
+import comfy.model_management
+import comfy.clip_model
+import comfy.background_removal.birefnet
+
+BG_REMOVAL_MODELS = {
+ "birefnet": comfy.background_removal.birefnet.BiRefNet
+}
+
+class BackgroundRemovalModel():
+ def __init__(self, json_config):
+ with open(json_config) as f:
+ config = json.load(f)
+
+ self.image_size = config.get("image_size", 1024)
+ self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
+ self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
+ self.model_type = config.get("model_type", "birefnet")
+ self.config = config.copy()
+ model_class = BG_REMOVAL_MODELS.get(self.model_type)
+
+ self.load_device = comfy.model_management.text_encoder_device()
+ offload_device = comfy.model_management.text_encoder_offload_device()
+ self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
+ self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
+ self.model.eval()
+
+ self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+
+ def load_sd(self, sd):
+ return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
+
+ def get_sd(self):
+ return self.model.state_dict()
+
+ def encode_image(self, image):
+ comfy.model_management.load_model_gpu(self.patcher)
+ H, W = image.shape[1], image.shape[2]
+ pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
+ out = self.model(pixel_values=pixel_values)
+ out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
+
+ mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(0)
+ if mask.shape[1] != 1:
+ mask = mask.movedim(-1, 1)
+
+ return mask
+
+
+def load_background_removal_model(sd):
+ if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
+        json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal", "birefnet.json")
+ else:
+ return None
+
+ bg_model = BackgroundRemovalModel(json_config)
+ m, u = bg_model.load_sd(sd)
+ if len(m) > 0:
+ logging.warning("missing background removal: {}".format(m))
+ u = set(u)
+ keys = list(sd.keys())
+ for k in keys:
+ if k not in u:
+ sd.pop(k)
+ return bg_model
+
+def load(ckpt_path):
+ sd = load_torch_file(ckpt_path)
+ return load_background_removal_model(sd)
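+
+# Usage sketch (illustrative only; the path and tensor layout below are assumptions):
+#
+#     bg_model = load("/path/to/birefnet.safetensors")  # BackgroundRemovalModel or None
+#     mask = bg_model.encode_image(image)               # image: ComfyUI IMAGE tensor [B, H, W, C] in 0..1
+#
+# encode_image returns a [B, 1, H, W] soft mask (sigmoid of the model output).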
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index dbaadf723..9dadb0093 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -90,8 +90,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
-parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
+parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
+parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
+parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
if comfy.options.args_parsing:
args = parser.parse_args()
diff --git a/comfy/context_windows.py b/comfy/context_windows.py
index cb44ee6e8..db57537a2 100644
--- a/comfy/context_windows.py
+++ b/comfy/context_windows.py
@@ -63,7 +63,11 @@ class IndexListContextWindow(ContextWindowABC):
dim = self.dim
if dim == 0 and full.shape[dim] == 1:
return full
- idx = tuple([slice(None)] * dim + [self.index_list])
+ indices = self.index_list
+ anchor_idx = getattr(self, 'causal_anchor_index', None)
+ if anchor_idx is not None and anchor_idx >= 0:
+ indices = [anchor_idx] + list(indices)
+ idx = tuple([slice(None)] * dim + [indices])
window = full[idx]
if retain_index_list:
idx = tuple([slice(None)] * dim + [retain_index_list])
@@ -113,7 +117,14 @@ def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, d
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
- indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
+ anchor_idx = getattr(window, 'causal_anchor_index', None)
+ if anchor_idx is not None and anchor_idx >= 0:
+ # anchor occupies one of the no-cond positions, so skip one fewer from window.index_list
+ skip_count = temporal_offset - 1
+ else:
+ skip_count = temporal_offset
+
+ indices = [i - temporal_offset for i in window.index_list[skip_count:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
@@ -150,7 +161,8 @@ class ContextFuseMethod:
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
class IndexListContextHandler(ContextHandlerABC):
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
- closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
+ closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False,
+ causal_window_fix: bool=True):
self.context_schedule = context_schedule
self.fuse_method = fuse_method
self.context_length = context_length
@@ -162,6 +174,7 @@ class IndexListContextHandler(ContextHandlerABC):
self.freenoise = freenoise
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
self.split_conds_to_windows = split_conds_to_windows
+ self.causal_window_fix = causal_window_fix
self.callbacks = {}
@@ -318,6 +331,14 @@ class IndexListContextHandler(ContextHandlerABC):
# allow processing to end between context window executions for faster Cancel
comfy.model_management.throw_exception_if_processing_interrupted()
+ # causal_window_fix: prepend a pre-window frame that will be stripped post-forward
+ anchor_applied = False
+ if self.causal_window_fix:
+ anchor_idx = window.index_list[0] - 1
+ if 0 <= anchor_idx < x_in.size(self.dim):
+ window.causal_anchor_index = anchor_idx
+ anchor_applied = True
+
for callback in comfy.patcher_extension.get_all_callbacks(IndexListCallbacks.EVALUATE_CONTEXT_WINDOWS, self.callbacks):
callback(self, model, x_in, conds, timestep, model_options, window_idx, window, model_options, device, first_device)
@@ -332,6 +353,12 @@ class IndexListContextHandler(ContextHandlerABC):
if device is not None:
for i in range(len(sub_conds_out)):
sub_conds_out[i] = sub_conds_out[i].to(x_in.device)
+
+ # strip causal_window_fix anchor if applied
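+            # narrow(dim, 1, n-1) drops the extra leading position added via causal_anchor_index
+            # so the window outputs line up with window.index_list again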
+ if anchor_applied:
+ for i in range(len(sub_conds_out)):
+ sub_conds_out[i] = sub_conds_out[i].narrow(self.dim, 1, sub_conds_out[i].shape[self.dim] - 1)
+
results.append(ContextResults(window_idx, sub_conds_out, sub_conds, window))
return results
diff --git a/comfy/deploy_environment.py b/comfy/deploy_environment.py
new file mode 100644
index 000000000..8c99a3584
--- /dev/null
+++ b/comfy/deploy_environment.py
@@ -0,0 +1,34 @@
+import functools
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_DEPLOY_ENV = "local-git"
+_ENV_FILENAME = ".comfy_environment"
+
+# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
+# We deliberately avoid `folder_paths.base_path` here because that is overridden
+# by the `--base-directory` CLI arg to a user-supplied path, whereas the
+# `.comfy_environment` marker is written by launchers/installers next to the
+# ComfyUI install itself.
+_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+
+
+@functools.cache
+def get_deploy_environment() -> str:
+ env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
+ try:
+ with open(env_file, encoding="utf-8") as f:
+ # Cap the read so a malformed or maliciously crafted file (e.g.
+ # a single huge line with no newline) can't blow up memory.
+ first_line = f.readline(128).strip()
+ value = "".join(c for c in first_line if 32 <= ord(c) < 127)
+ if value:
+ return value
+ except FileNotFoundError:
+ pass
+ except Exception as e:
+ logger.error("Failed to read %s: %s", env_file, e)
+
+ return _DEFAULT_DEPLOY_ENV
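+
+# Usage sketch (hypothetical caller, for illustration): the returned value is a free-form
+# string and falls back to "local-git" when no marker file is present, e.g.
+#
+#     if get_deploy_environment() != "local-git":
+#         ...  # running from a packaged or managed install rather than a git checkout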
diff --git a/comfy/hooks.py b/comfy/hooks.py
index 1a76c7ba4..5458fc3d8 100644
--- a/comfy/hooks.py
+++ b/comfy/hooks.py
@@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
- '''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
+ '''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 6978eb717..c53ac4b2b 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1810,3 +1810,119 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
+
+
+@torch.no_grad()
+def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
+ num_frame_per_block=1):
+ """
+ Autoregressive video sampler: block-by-block denoising with KV cache
+ and flow-match re-noising for Causal Forcing / Self-Forcing models.
+
+ Requires a Causal-WAN compatible model (diffusion_model must expose
+ init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
+
+ All AR-loop parameters are passed via the SamplerARVideo node, not read
+ from the checkpoint or transformer_options.
+ """
+ extra_args = {} if extra_args is None else extra_args
+ model_options = extra_args.get("model_options", {})
+ transformer_options = model_options.get("transformer_options", {})
+
+ if x.ndim != 5:
+ raise ValueError(
+ f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
+ "This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
+ )
+
+ inner_model = model.inner_model.inner_model
+ causal_model = inner_model.diffusion_model
+
+ if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
+ raise TypeError(
+ "ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
+ "exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
+ "does not support this interface — choose a different sampler."
+ )
+
+ seed = extra_args.get("seed", 0)
+
+ bs, c, lat_t, lat_h, lat_w = x.shape
+ frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
+ num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
+ device = x.device
+ model_dtype = inner_model.get_dtype()
+
+ kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
+ crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
+
+ output = torch.zeros_like(x)
+ s_in = x.new_ones([x.shape[0]])
+ current_start_frame = 0
+
+ # I2V: seed KV cache with the initial image latent before the denoising loop
+ initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
+ if initial_latent is not None:
+ initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
+ n_init = initial_latent.shape[2]
+ output[:, :, :n_init] = initial_latent
+
+ ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
+ transformer_options["ar_state"] = ar_state
+ zero_sigma = sigmas.new_zeros([1])
+ _ = model(initial_latent, zero_sigma * s_in, **extra_args)
+
+ current_start_frame = n_init
+ remaining = lat_t - n_init
+ num_blocks = -(-remaining // num_frame_per_block)
+
+ num_sigma_steps = len(sigmas) - 1
+ total_real_steps = num_blocks * num_sigma_steps
+ step_count = 0
+
+ try:
+ for block_idx in trange(num_blocks, disable=disable):
+ bf = min(num_frame_per_block, lat_t - current_start_frame)
+ fs, fe = current_start_frame, current_start_frame + bf
+ noisy_input = x[:, :, fs:fe]
+
+ ar_state = {
+ "start_frame": current_start_frame,
+ "kv_caches": kv_caches,
+ "crossattn_caches": crossattn_caches,
+ }
+ transformer_options["ar_state"] = ar_state
+
+ for i in range(num_sigma_steps):
+ denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
+
+ if callback is not None:
+ scaled_i = step_count * num_sigma_steps // total_real_steps
+ callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
+ "sigma_hat": sigmas[i], "denoised": denoised})
+
+ if sigmas[i + 1] == 0:
+ noisy_input = denoised
+ else:
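+                    # Flow-match re-noising: treat denoised as x0 and move it back to noise
+                    # level sigma_next via x = (1 - sigma)*x0 + sigma*eps with freshly seeded noise.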
+ sigma_next = sigmas[i + 1]
+ torch.manual_seed(seed + block_idx * 1000 + i)
+ fresh_noise = torch.randn_like(denoised)
+ noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
+
+ for cache in kv_caches:
+ cache["end"] -= bf * frame_seq_len
+
+ step_count += 1
+
+ output[:, :, fs:fe] = noisy_input
+
+ for cache in kv_caches:
+ cache["end"] -= bf * frame_seq_len
+ zero_sigma = sigmas.new_zeros([1])
+ _ = model(noisy_input, zero_sigma * s_in, **extra_args)
+
+ current_start_frame += bf
+ finally:
+ transformer_options.pop("ar_state", None)
+
+ return output
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index 6a57bca1c..91bebed3d 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -9,6 +9,7 @@ class LatentFormat:
latent_rgb_factors_reshape = None
taesd_decoder_name = None
spacial_downscale_ratio = 8
+ temporal_downscale_ratio = 1
def process_in(self, latent):
return latent * self.scale_factor
@@ -224,6 +225,7 @@ class Flux2(LatentFormat):
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
+ self.taesd_decoder_name = "taef2_decoder"
def process_in(self, latent):
return latent
@@ -234,6 +236,7 @@ class Flux2(LatentFormat):
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
+ temporal_downscale_ratio = 6
def __init__(self):
self.scale_factor = 1.0
@@ -277,6 +280,7 @@ class LTXV(LatentFormat):
latent_channels = 128
latent_dimensions = 3
spacial_downscale_ratio = 32
+ temporal_downscale_ratio = 8
def __init__(self):
self.latent_rgb_factors = [
@@ -420,6 +424,7 @@ class LTXAV(LTXV):
class HunyuanVideo(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 4
scale_factor = 0.476986
latent_rgb_factors = [
[-0.0395, -0.0331, 0.0445],
@@ -446,6 +451,7 @@ class HunyuanVideo(LatentFormat):
class Cosmos1CV8x8x8(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 8
latent_rgb_factors = [
[ 0.1817, 0.2284, 0.2423],
@@ -471,6 +477,7 @@ class Cosmos1CV8x8x8(LatentFormat):
class Wan21(LatentFormat):
latent_channels = 16
latent_dimensions = 3
+ temporal_downscale_ratio = 4
latent_rgb_factors = [
[-0.1299, -0.1692, 0.2932],
@@ -733,6 +740,7 @@ class HunyuanVideo15(LatentFormat):
latent_channels = 32
latent_dimensions = 3
spacial_downscale_ratio = 16
+ temporal_downscale_ratio = 4
scale_factor = 1.03682
taesd_decoder_name = "lighttaehy1_5"
@@ -783,3 +791,29 @@ class ZImagePixelSpace(ChromaRadiance):
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass
+
+class CogVideoX(LatentFormat):
+ """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
+
+ scale_factor matches the vae/config.json scaling_factor for the 2b variant.
+ The 5b-class checkpoints (CogVideoX-5b, CogVideoX-1.5-5B, CogVideoX-Fun-V1.5-*)
+ use a different value; see CogVideoX1_5 below.
+ """
+ latent_channels = 16
+ latent_dimensions = 3
+ temporal_downscale_ratio = 4
+
+ def __init__(self):
+ self.scale_factor = 1.15258426
+
+
+class CogVideoX1_5(CogVideoX):
+ """Latent format for 5b-class CogVideoX checkpoints.
+
+ Covers THUDM/CogVideoX-5b, THUDM/CogVideoX-1.5-5B, and the CogVideoX-Fun
+ V1.5-5b family (including VOID inpainting). All of these have
+ scaling_factor=0.7 in their vae/config.json. Auto-selected in
+ supported_models.CogVideoX_T2V based on transformer hidden dim.
+ """
+ def __init__(self):
+ self.scale_factor = 0.7
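+
+# Both classes rely on the base LatentFormat.process_in/process_out, so the scale factor
+# above is multiplied into latents on the way into the model and divided back out before
+# decoding (assumption based on the default LatentFormat behaviour).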
diff --git a/comfy/ldm/cogvideo/__init__.py b/comfy/ldm/cogvideo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py
new file mode 100644
index 000000000..fb475ed53
--- /dev/null
+++ b/comfy/ldm/cogvideo/model.py
@@ -0,0 +1,573 @@
+# CogVideoX 3D Transformer - ported to ComfyUI native ops
+# Architecture reference: diffusers CogVideoXTransformer3DModel
+# Style reference: comfy/ldm/wan/model.py
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.patcher_extension
+import comfy.ldm.common_dit
+
+
+def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
+ """Returns (cos, sin) each with shape [seq_len, dim].
+
+ Frequencies are computed at dim//2 resolution then repeat_interleaved
+ to full dim, matching CogVideoX's interleaved (real, imag) pair format.
+ """
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
+ angles = torch.outer(pos.float(), freqs.float())
+ cos = angles.cos().repeat_interleave(2, dim=-1).float()
+ sin = angles.sin().repeat_interleave(2, dim=-1).float()
+ return (cos, sin)
+
+
+def apply_rotary_emb(x, freqs_cos_sin):
+ """Apply CogVideoX rotary embedding to query or key tensor.
+
+ x: [B, heads, seq_len, head_dim]
+    freqs_cos_sin: (cos, sin) each [seq_len, head_dim] (already repeat_interleaved to full width)
+
+ Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
+ head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
+ """
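+    # For each adjacent pair (a, b) sharing the angle theta this computes
+    # (a*cos(theta) - b*sin(theta), b*cos(theta) + a*sin(theta)),
+    # i.e. multiplication of the complex number a + i*b by e^{i*theta}.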
+ cos, sin = freqs_cos_sin
+ cos = cos[None, None, :, :].to(x.device)
+ sin = sin[None, None, :, :].to(x.device)
+
+ # Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+ return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+
+def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
+ half = dim // 2
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
+ args = timesteps[:, None].float() * freqs[None] * scale
+ embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
+ if flip_sin_to_cos:
+ embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
+ if isinstance(spatial_size, int):
+ spatial_size = (spatial_size, spatial_size)
+
+ grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
+ grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
+ grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
+
+ grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
+
+ embed_dim_spatial = 2 * (embed_dim // 3)
+ embed_dim_temporal = embed_dim // 3
+
+ pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
+ pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
+
+ T, H, W = grid_t.shape
+ pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
+ pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
+
+ return pos_embed
+
+
+def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
+ T, H, W = grid_h.shape
+ half_dim = embed_dim // 2
+ pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
+ pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
+ return torch.cat([pos_h, pos_w], dim=-1)
+
+
+def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
+ half = embed_dim // 2
+ freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
+ args = pos.float().reshape(-1)[:, None] * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if embed_dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+
+class CogVideoXPatchEmbed(nn.Module):
+ def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
+ text_dim=4096, bias=True, sample_width=90, sample_height=60,
+ sample_frames=49, temporal_compression_ratio=4,
+ max_text_seq_length=226, spatial_interpolation_scale=1.875,
+ temporal_interpolation_scale=1.0, use_positional_embeddings=True,
+ use_learned_positional_embeddings=True,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.dim = dim
+ self.sample_height = sample_height
+ self.sample_width = sample_width
+ self.sample_frames = sample_frames
+ self.temporal_compression_ratio = temporal_compression_ratio
+ self.max_text_seq_length = max_text_seq_length
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.use_positional_embeddings = use_positional_embeddings
+ self.use_learned_positional_embeddings = use_learned_positional_embeddings
+
+ if patch_size_t is None:
+ self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
+ else:
+ self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
+
+ self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
+
+ if use_positional_embeddings or use_learned_positional_embeddings:
+ persistent = use_learned_positional_embeddings
+ pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
+ self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
+
+ def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
+ post_patch_height = sample_height // self.patch_size
+ post_patch_width = sample_width // self.patch_size
+ post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
+ if self.patch_size_t is not None:
+ post_time_compression_frames = post_time_compression_frames // self.patch_size_t
+ num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+
+ pos_embedding = get_3d_sincos_pos_embed(
+ self.dim,
+ (post_patch_width, post_patch_height),
+ post_time_compression_frames,
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ device=device,
+ )
+ pos_embedding = pos_embedding.reshape(-1, self.dim)
+ joint_pos_embedding = pos_embedding.new_zeros(
+ 1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
+ )
+ joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
+ return joint_pos_embedding
+
+ def forward(self, text_embeds, image_embeds):
+ input_dtype = text_embeds.dtype
+ text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
+ batch_size, num_frames, channels, height, width = image_embeds.shape
+
+ proj_dtype = self.proj.weight.dtype
+ if self.patch_size_t is None:
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
+ image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
+ image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+ image_embeds = image_embeds.flatten(3).transpose(2, 3)
+ image_embeds = image_embeds.flatten(1, 2)
+ else:
+ p = self.patch_size
+ p_t = self.patch_size_t
+ image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+ image_embeds = image_embeds.reshape(
+ batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+ )
+ image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+ image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
+
+ embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
+
+ if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+ text_seq_length = text_embeds.shape[1]
+ num_image_patches = image_embeds.shape[1]
+
+ if self.use_learned_positional_embeddings:
+ image_pos = self.pos_embedding[
+ :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
+ ].to(device=embeds.device, dtype=embeds.dtype)
+ else:
+ image_pos = get_3d_sincos_pos_embed(
+ self.dim,
+ (width // self.patch_size, height // self.patch_size),
+ num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
+ self.spatial_interpolation_scale,
+ self.temporal_interpolation_scale,
+ device=embeds.device,
+ ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
+
+ # Build joint: zeros for text + sincos for image
+ joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
+ joint_pos[:, text_seq_length:] = image_pos
+ embeds = embeds + joint_pos
+
+ return embeds
+
+
+class CogVideoXLayerNormZero(nn.Module):
+ def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
+ self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
+
+ def forward(self, hidden_states, encoder_hidden_states, temb):
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
+
+
+class CogVideoXAdaLayerNorm(nn.Module):
+ def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+ self.silu = nn.SiLU()
+ self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
+ self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
+
+ def forward(self, x, temb):
+ temb = self.linear(self.silu(temb))
+ shift, scale = temb.chunk(2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
+class CogVideoXBlock(nn.Module):
+ def __init__(self, dim, num_heads, head_dim, time_dim,
+ eps=1e-5, ff_inner_dim=None, ff_bias=True,
+ device=None, dtype=None, operations=None):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = head_dim
+
+ self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
+
+ # Self-attention (joint text + latent)
+ self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
+ self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
+ self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
+ self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
+ self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
+ self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
+
+ self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
+
+ # Feed-forward (GELU approximate)
+ inner_dim = ff_inner_dim or dim * 4
+ self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
+ self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
+
+ def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None):
+ if transformer_options is None:
+ transformer_options = {}
+ text_seq_length = encoder_hidden_states.size(1)
+
+ # Norm & modulate
+ norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
+
+ # Joint self-attention
+ qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
+ b, s, _ = qkv_input.shape
+ n, d = self.num_heads, self.head_dim
+
+ q = self.q(qkv_input).view(b, s, n, d)
+ k = self.k(qkv_input).view(b, s, n, d)
+ v = self.v(qkv_input)
+
+ q = self.norm_q(q).view(b, s, n, d)
+ k = self.norm_k(k).view(b, s, n, d)
+
+ # Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
+ if image_rotary_emb is not None:
+ q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
+ k_img = k[:, text_seq_length:].transpose(1, 2)
+ q_img = apply_rotary_emb(q_img, image_rotary_emb)
+ k_img = apply_rotary_emb(k_img, image_rotary_emb)
+ q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
+ k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
+
+ attn_out = optimized_attention(
+ q.reshape(b, s, n * d),
+ k.reshape(b, s, n * d),
+ v,
+ heads=self.num_heads,
+ transformer_options=transformer_options,
+ )
+
+ attn_out = self.attn_out(attn_out)
+
+ attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
+
+ hidden_states = hidden_states + gate_msa * attn_hidden
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
+
+ # Norm & modulate for FF
+ norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
+
+ # Feed-forward (GELU on concatenated text + latent)
+ ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
+ ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
+
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
+
+ return hidden_states, encoder_hidden_states
+
+
+class CogVideoXTransformer3DModel(nn.Module):
+ def __init__(self,
+ num_attention_heads=30,
+ attention_head_dim=64,
+ in_channels=16,
+ out_channels=16,
+ flip_sin_to_cos=True,
+ freq_shift=0,
+ time_embed_dim=512,
+ ofs_embed_dim=None,
+ text_embed_dim=4096,
+ num_layers=30,
+ dropout=0.0,
+ attention_bias=True,
+ sample_width=90,
+ sample_height=60,
+ sample_frames=49,
+ patch_size=2,
+ patch_size_t=None,
+ temporal_compression_ratio=4,
+ max_text_seq_length=226,
+ spatial_interpolation_scale=1.875,
+ temporal_interpolation_scale=1.0,
+ use_rotary_positional_embeddings=False,
+ use_learned_positional_embeddings=False,
+ patch_bias=True,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ dim = num_attention_heads * attention_head_dim
+ self.dim = dim
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.patch_size = patch_size
+ self.patch_size_t = patch_size_t
+ self.max_text_seq_length = max_text_seq_length
+ self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
+
+ # 1. Patch embedding
+ self.patch_embed = CogVideoXPatchEmbed(
+ patch_size=patch_size,
+ patch_size_t=patch_size_t,
+ in_channels=in_channels,
+ dim=dim,
+ text_dim=text_embed_dim,
+ bias=patch_bias,
+ sample_width=sample_width,
+ sample_height=sample_height,
+ sample_frames=sample_frames,
+ temporal_compression_ratio=temporal_compression_ratio,
+ max_text_seq_length=max_text_seq_length,
+ spatial_interpolation_scale=spatial_interpolation_scale,
+ temporal_interpolation_scale=temporal_interpolation_scale,
+ use_positional_embeddings=not use_rotary_positional_embeddings,
+ use_learned_positional_embeddings=use_learned_positional_embeddings,
+ device=device, dtype=torch.float32, operations=operations,
+ )
+
+ # 2. Time embedding
+ self.time_proj_dim = dim
+ self.time_proj_flip = flip_sin_to_cos
+ self.time_proj_shift = freq_shift
+ self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
+ self.time_embedding_act = nn.SiLU()
+ self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
+
+ # Optional OFS embedding (CogVideoX 1.5 I2V)
+ self.ofs_proj_dim = ofs_embed_dim
+ if ofs_embed_dim:
+ self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
+ self.ofs_embedding_act = nn.SiLU()
+ self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
+ else:
+ self.ofs_embedding_linear_1 = None
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList([
+ CogVideoXBlock(
+ dim=dim,
+ num_heads=num_attention_heads,
+ head_dim=attention_head_dim,
+ time_dim=time_embed_dim,
+ eps=1e-5,
+ device=device, dtype=dtype, operations=operations,
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
+
+ # 4. Output
+ self.norm_out = CogVideoXAdaLayerNorm(
+ time_dim=time_embed_dim, dim=dim, eps=1e-5,
+ device=device, dtype=dtype, operations=operations,
+ )
+
+ if patch_size_t is None:
+ output_dim = patch_size * patch_size * out_channels
+ else:
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+ self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
+
+ self.spatial_interpolation_scale = spatial_interpolation_scale
+ self.temporal_interpolation_scale = temporal_interpolation_scale
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
+ if transformer_options is None:
+ transformer_options = {}
+ return comfy.patcher_extension.WrapperExecutor.new_class_executor(
+ self._forward,
+ self,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
+ ).execute(x, timestep, context, ofs, transformer_options, **kwargs)
+
+ def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
+ if transformer_options is None:
+ transformer_options = {}
+ # ComfyUI passes [B, C, T, H, W]
+ batch_size, channels, t, h, w = x.shape
+
+ # Pad to patch size (temporal + spatial), same pattern as WAN
+ p_t = self.patch_size_t if self.patch_size_t is not None else 1
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
+
+ # CogVideoX expects [B, T, C, H, W]
+ x = x.permute(0, 2, 1, 3, 4)
+ batch_size, num_frames, channels, height, width = x.shape
+
+ # Time embedding
+ t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
+
+ if self.ofs_embedding_linear_1 is not None and ofs is not None:
+ ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
+ ofs_emb = ofs_emb.to(dtype=x.dtype)
+ ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
+ emb = emb + ofs_emb
+
+ # Patch embedding
+ hidden_states = self.patch_embed(context, x)
+
+ text_seq_length = context.shape[1]
+ encoder_hidden_states = hidden_states[:, :text_seq_length]
+ hidden_states = hidden_states[:, text_seq_length:]
+
+ # Rotary embeddings (if used)
+ image_rotary_emb = None
+ if self.use_rotary_positional_embeddings:
+ post_patch_height = height // self.patch_size
+ post_patch_width = width // self.patch_size
+ if self.patch_size_t is None:
+ post_time = num_frames
+ else:
+ post_time = num_frames // self.patch_size_t
+ image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
+
+ # Transformer blocks
+ for i, block in enumerate(self.blocks):
+ hidden_states, encoder_hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=emb,
+ image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
+ )
+
+ hidden_states = self.norm_final(hidden_states)
+
+ # Output projection
+ hidden_states = self.norm_out(hidden_states, temb=emb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # Unpatchify
+ p = self.patch_size
+ p_t = self.patch_size_t
+
+ if p_t is None:
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+ else:
+ output = hidden_states.reshape(
+ batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+ )
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
+
+ # Back to ComfyUI format [B, C, T, H, W] and crop padding
+ output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
+ return output
+
+ def _get_rotary_emb(self, h, w, t, device):
+ """Compute CogVideoX 3D rotary positional embeddings.
+
+ For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions
+ are integer arange computed at max_size, then sliced to actual size.
+ For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
+ scaled by spatial_interpolation_scale.
+ """
+ d = self.attention_head_dim
+ dim_t = d // 4
+ dim_h = d // 8 * 3
+ dim_w = d // 8 * 3
+
+ if self.patch_size_t is not None:
+ # CogVideoX 1.5: "slice" mode — positions are simple integer indices
+ # Compute at max(sample_size, actual_size) then slice to actual
+ base_h = self.patch_embed.sample_height // self.patch_size
+ base_w = self.patch_embed.sample_width // self.patch_size
+ max_h = max(base_h, h)
+ max_w = max(base_w, w)
+
+ grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+ grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+ grid_t = torch.arange(t, device=device, dtype=torch.float32)
+ else:
+ # CogVideoX 1.0: "linspace" mode with interpolation scale
+ grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
+ grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
+ grid_t = torch.arange(t, device=device, dtype=torch.float32)
+
+ freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
+ freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
+ freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
+
+ t_cos, t_sin = freqs_t
+ h_cos, h_sin = freqs_h
+ w_cos, w_sin = freqs_w
+
+ # Slice to actual size (for "slice" mode where grids may be larger)
+ t_cos, t_sin = t_cos[:t], t_sin[:t]
+ h_cos, h_sin = h_cos[:h], h_sin[:h]
+ w_cos, w_sin = w_cos[:w], w_sin[:w]
+
+ # Broadcast and concatenate into [T*H*W, head_dim]
+ t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
+ t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
+ h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
+ h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
+ w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
+ w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
+
+ cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
+ sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
+ return (cos, sin)
diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py
new file mode 100644
index 000000000..d4e6f321e
--- /dev/null
+++ b/comfy/ldm/cogvideo/vae.py
@@ -0,0 +1,566 @@
+# CogVideoX VAE - ported to ComfyUI native ops
+# Architecture reference: diffusers AutoencoderKLCogVideoX
+# Style reference: comfy/ldm/wan/vae.py
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+
+class CausalConv3d(nn.Module):
+    """Causal 3D convolution with temporal padding and an optional streaming cache.
+
+    The temporal axis is padded causally, either with the cached tail of the
+    previous chunk (conv_cache) or with copies of the first frame. When the
+    input is a single frame and no cache is present, the temporal kernel is
+    collapsed by summation so the 3D conv runs as a cheaper 2D-style conv.
+    """
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
+ super().__init__()
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size,) * 3
+
+ time_kernel, height_kernel, width_kernel = kernel_size
+ self.time_kernel_size = time_kernel
+ self.pad_mode = pad_mode
+
+ height_pad = (height_kernel - 1) // 2
+ width_pad = (width_kernel - 1) // 2
+ self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
+
+ stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
+ dilation = (dilation, 1, 1)
+ self.conv = ops.Conv3d(
+ in_channels, out_channels, kernel_size,
+ stride=stride, dilation=dilation,
+ padding=(0, height_pad, width_pad),
+ )
+
+ def forward(self, x, conv_cache=None):
+ if self.pad_mode == "replicate":
+ x = F.pad(x, self.time_causal_padding, mode="replicate")
+ conv_cache = None
+ else:
+ kernel_t = self.time_kernel_size
+ if kernel_t > 1:
+ if conv_cache is None and x.shape[2] == 1:
+ # Fast path: single frame, no cache. All temporal padding
+ # frames are copies of the input (replicate-style), so the
+ # 3D conv reduces to a 2D conv with summed temporal kernel.
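+                    # i.e. conv3d([x, ..., x], w) == conv2d(x, w.sum(dim=time)) because
+                    # every temporal tap sees the same frame.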
+ w = comfy.ops.cast_to_input(self.conv.weight, x)
+ b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
+ w2d = w.sum(dim=2, keepdim=True)
+ out = F.conv3d(x, w2d, b,
+ self.conv.stride, self.conv.padding,
+ self.conv.dilation, self.conv.groups)
+ return out, None
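+            # General path: prepend causal context (cached trailing frames, or copies of
+            # the current chunk's first frame when there is no cache yet) before the conv.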
+ cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
+ x = torch.cat(cached + [x], dim=2)
+ conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
+
+ out = self.conv(x)
+ return out, conv_cache
+
+
+def _interpolate_zq(zq, target_size):
+ """Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
+ t = target_size[0]
+ if t > 1 and t % 2 == 1:
+ z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
+ z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
+ return torch.cat([z_first, z_rest], dim=2)
+ return F.interpolate(zq, size=target_size)
+
+
+class SpatialNorm3D(nn.Module):
+ """Spatially conditioned normalization."""
+ def __init__(self, f_channels, zq_channels, groups=32):
+ super().__init__()
+ self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
+ self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+ self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
+
+ def forward(self, f, zq, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+
+ if zq.shape[-3:] != f.shape[-3:]:
+ zq = _interpolate_zq(zq, f.shape[-3:])
+
+ conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
+ conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
+
+ return self.norm_layer(f) * conv_y + conv_b, new_cache
+
+
+class ResnetBlock3D(nn.Module):
+ """3D ResNet block with optional spatial norm."""
+ def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
+ eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
+ super().__init__()
+ out_channels = out_channels or in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.spatial_norm_dim = spatial_norm_dim
+
+ if act_fn == "silu":
+ self.nonlinearity = nn.SiLU()
+ elif act_fn == "swish":
+ self.nonlinearity = nn.SiLU()
+ else:
+ self.nonlinearity = nn.SiLU()
+
+ if spatial_norm_dim is None:
+ self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
+ self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
+ else:
+ self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
+ self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
+
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+
+ if temb_channels > 0:
+ self.temb_proj = ops.Linear(temb_channels, out_channels)
+
+ self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
+
+ if in_channels != out_channels:
+ self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ else:
+ self.conv_shortcut = None
+
+ def forward(self, x, temb=None, zq=None, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+ residual = x
+
+ if zq is not None:
+ x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
+ else:
+ x = self.norm1(x)
+
+ x = self.nonlinearity(x)
+ x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
+
+ if temb is not None and hasattr(self, "temb_proj"):
+ x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if zq is not None:
+ x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
+ else:
+ x = self.norm2(x)
+
+ x = self.nonlinearity(x)
+ x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return x + residual, new_cache
+
+
+class Downsample3D(nn.Module):
+ """3D downsampling with optional temporal compression."""
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
+ super().__init__()
+ self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time:
+ b, c, t, h, w = x.shape
+ x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
+ if t % 2 == 1:
+ x_first, x_rest = x[..., 0], x[..., 1:]
+ if x_rest.shape[-1] > 0:
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
+ x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
+ else:
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
+ x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
+
+ pad = (0, 1, 0, 1)
+ x = F.pad(x, pad, mode="constant", value=0)
+ b, c, t, h, w = x.shape
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.conv(x)
+ x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
+ return x
+
+
+class Upsample3D(nn.Module):
+ """3D upsampling with optional temporal decompression."""
+ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
+ super().__init__()
+ self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+ self.compress_time = compress_time
+
+ def forward(self, x):
+ if self.compress_time:
+ if x.shape[2] > 1 and x.shape[2] % 2 == 1:
+ x_first, x_rest = x[:, :, 0], x[:, :, 1:]
+ x_first = F.interpolate(x_first, scale_factor=2.0)
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
+ x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
+ elif x.shape[2] > 1:
+ x = F.interpolate(x, scale_factor=2.0)
+ else:
+ x = x.squeeze(2)
+ x = F.interpolate(x, scale_factor=2.0)
+ x = x[:, :, None, :, :]
+ else:
+ b, c, t, h, w = x.shape
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = F.interpolate(x, scale_factor=2.0)
+ x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
+
+ b, c, t, h, w = x.shape
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
+ x = self.conv(x)
+ x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
+ return x
+
+
+class DownBlock3D(nn.Module):
+ def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
+ eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
+ compress_time=False, pad_mode="first"):
+ super().__init__()
+ self.resnets = nn.ModuleList([
+ ResnetBlock3D(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
+ )
+ for i in range(num_layers)
+ ])
+ self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
+
+ def forward(self, x, temb=None, zq=None, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+ for i, resnet in enumerate(self.resnets):
+ x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
+ if self.downsamplers is not None:
+ for ds in self.downsamplers:
+ x = ds(x)
+ return x, new_cache
+
+
+class MidBlock3D(nn.Module):
+ def __init__(self, in_channels, temb_channels=0, num_layers=1,
+ eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
+ super().__init__()
+ self.resnets = nn.ModuleList([
+ ResnetBlock3D(
+ in_channels=in_channels, out_channels=in_channels,
+ temb_channels=temb_channels, groups=groups, eps=eps,
+ act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, x, temb=None, zq=None, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+ for i, resnet in enumerate(self.resnets):
+ x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
+ return x, new_cache
+
+
+class UpBlock3D(nn.Module):
+ def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
+ eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
+ add_upsample=True, compress_time=False, pad_mode="first"):
+ super().__init__()
+ self.resnets = nn.ModuleList([
+ ResnetBlock3D(
+ in_channels=in_channels if i == 0 else out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels, groups=groups, eps=eps,
+ act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
+ )
+ for i in range(num_layers)
+ ])
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
+
+ def forward(self, x, temb=None, zq=None, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+ for i, resnet in enumerate(self.resnets):
+ x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
+ if self.upsamplers is not None:
+ for us in self.upsamplers:
+ x = us(x)
+ return x, new_cache
+
+
+class Encoder3D(nn.Module):
+ def __init__(self, in_channels=3, out_channels=16,
+ block_out_channels=(128, 256, 256, 512),
+ layers_per_block=3, act_fn="silu",
+ eps=1e-6, groups=32, pad_mode="first",
+ temporal_compression_ratio=4):
+ super().__init__()
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
+
+ self.down_blocks = nn.ModuleList()
+ output_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ self.down_blocks.append(DownBlock3D(
+ in_channels=input_channel, out_channels=output_channel,
+ temb_channels=0, num_layers=layers_per_block,
+ eps=eps, act_fn=act_fn, groups=groups,
+ add_downsample=not is_final, compress_time=compress_time,
+ ))
+
+ self.mid_block = MidBlock3D(
+ in_channels=block_out_channels[-1], temb_channels=0,
+ num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
+ )
+
+ self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, x, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+
+ x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
+
+ for i, block in enumerate(self.down_blocks):
+ key = f"down_block_{i}"
+ x, new_cache[key] = block(x, None, None, conv_cache.get(key))
+
+ x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
+
+ x = self.norm_out(x)
+ x = self.conv_act(x)
+ x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
+
+ return x, new_cache
+
+
+class Decoder3D(nn.Module):
+ def __init__(self, in_channels=16, out_channels=3,
+ block_out_channels=(128, 256, 256, 512),
+ layers_per_block=3, act_fn="silu",
+ eps=1e-6, groups=32, pad_mode="first",
+ temporal_compression_ratio=4):
+ super().__init__()
+ reversed_channels = list(reversed(block_out_channels))
+ temporal_compress_level = int(np.log2(temporal_compression_ratio))
+
+ self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
+
+ self.mid_block = MidBlock3D(
+ in_channels=reversed_channels[0], temb_channels=0,
+ num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
+ spatial_norm_dim=in_channels, pad_mode=pad_mode,
+ )
+
+ self.up_blocks = nn.ModuleList()
+ output_channel = reversed_channels[0]
+ for i in range(len(block_out_channels)):
+ prev_channel = output_channel
+ output_channel = reversed_channels[i]
+ is_final = i == len(block_out_channels) - 1
+ compress_time = i < temporal_compress_level
+
+ self.up_blocks.append(UpBlock3D(
+ in_channels=prev_channel, out_channels=output_channel,
+ temb_channels=0, num_layers=layers_per_block + 1,
+ eps=eps, act_fn=act_fn, groups=groups,
+ spatial_norm_dim=in_channels,
+ add_upsample=not is_final, compress_time=compress_time,
+ ))
+
+ self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
+ self.conv_act = nn.SiLU()
+ self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
+
+ def forward(self, sample, conv_cache=None):
+ new_cache = {}
+ conv_cache = conv_cache or {}
+
+ x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
+
+ x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
+
+ for i, block in enumerate(self.up_blocks):
+ key = f"up_block_{i}"
+ x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
+
+ x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
+ x = self.conv_act(x)
+ x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
+
+ return x, new_cache
+
+
+class AutoencoderKLCogVideoX(nn.Module):
+ """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
+
+ Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
+ on the full (low-res) tensor, then the expensive spatial-only up_blocks +
+ norm_out + conv_out are processed in small temporal chunks with conv_cache
+ carrying causal state between chunks. This keeps peak VRAM proportional to
+ chunk_size rather than total frame count.
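+
+    Illustrative shape flow (assuming the default 8x spatial / 4x temporal
+    compression): a [B, 16, 13, 60, 90] latent decodes to a [B, 3, 49, 480, 720]
+    clip, with only chunk_size frames at a time resident in the expensive
+    spatial-only layers.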
+ """
+
+ def __init__(self,
+ in_channels=3, out_channels=3,
+ block_out_channels=(128, 256, 256, 512),
+ latent_channels=16, layers_per_block=3,
+ act_fn="silu", eps=1e-6, groups=32,
+ temporal_compression_ratio=4,
+ ):
+ super().__init__()
+ self.latent_channels = latent_channels
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ self.encoder = Encoder3D(
+ in_channels=in_channels, out_channels=latent_channels,
+ block_out_channels=block_out_channels, layers_per_block=layers_per_block,
+ act_fn=act_fn, eps=eps, groups=groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+ self.decoder = Decoder3D(
+ in_channels=latent_channels, out_channels=out_channels,
+ block_out_channels=block_out_channels, layers_per_block=layers_per_block,
+ act_fn=act_fn, eps=eps, groups=groups,
+ temporal_compression_ratio=temporal_compression_ratio,
+ )
+
+ self.num_latent_frames_batch_size = 2
+ self.num_sample_frames_batch_size = 8
+
+ def encode(self, x):
+ t = x.shape[2]
+ frame_batch = self.num_sample_frames_batch_size
+ remainder = t % frame_batch
+ conv_cache = None
+ enc = []
+
+        # Process the remainder frames first so that only the first chunk can have an
+        # odd temporal dimension, which is the case Downsample3D's first-frame handling
+        # during temporal compression is designed for.
+ if remainder > 0:
+ chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache)
+ enc.append(chunk.to(x.device))
+
+ for start in range(remainder, t, frame_batch):
+ chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache)
+ enc.append(chunk.to(x.device))
+
+ enc = torch.cat(enc, dim=2)
+ mean, _ = enc.chunk(2, dim=1)
+ return mean
+
+ def decode(self, z):
+ return self._decode_rolling(z)
+
+ def _decode_batched(self, z):
+ """Original batched decode - processes 2 latent frames through full decoder."""
+ t = z.shape[2]
+ frame_batch = self.num_latent_frames_batch_size
+ num_batches = max(t // frame_batch, 1)
+ conv_cache = None
+ dec = []
+ for i in range(num_batches):
+ remaining = t % frame_batch
+ start = frame_batch * i + (0 if i == 0 else remaining)
+ end = frame_batch * (i + 1) + remaining
+ chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
+ dec.append(chunk.cpu())
+ return torch.cat(dec, dim=2).to(z.device)
+
+ def _decode_rolling(self, z):
+ """Rolling decode - processes low-res layers on full tensor, then rolls
+ through expensive high-res layers in temporal chunks."""
+ decoder = self.decoder
+ device = z.device
+
+ # Determine which up_blocks have temporal upsample vs spatial-only.
+ # Temporal up_blocks are cheap (low res), spatial-only are expensive.
+ temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
+ split_at = temporal_compress_level # first N up_blocks do temporal upsample
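+        # e.g. temporal_compression_ratio=4 -> split_at=2: up_blocks[0:2] also upsample
+        # time; the remaining blocks are spatial-only and dominate compute and VRAM.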
+
+ # Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
+ x, _ = decoder.conv_in(z)
+ x, _ = decoder.mid_block(x, None, z)
+
+ for i in range(split_at):
+ x, _ = decoder.up_blocks[i](x, None, z)
+
+ # Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
+ remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
+ chunk_size = 4 # pixel frames per chunk through high-res layers
+ t_expanded = x.shape[2]
+
+ if t_expanded <= chunk_size or len(remaining_blocks) == 0:
+ # Small enough to process in one go
+ for i in remaining_blocks:
+ x, _ = decoder.up_blocks[i](x, None, z)
+ x, _ = decoder.norm_out(x, z)
+ x = decoder.conv_act(x)
+ x, _ = decoder.conv_out(x)
+ return x
+
+ # Expand z temporally once to match Phase 2's time dimension.
+ # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
+ # for the old approach of pre-interpolating to every pixel resolution).
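+        # (e.g. a [1, 16, 49, 60, 90] fp32 tensor is ~17 MB)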
+ z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
+
+ # Process in temporal chunks, interpolating spatially per-chunk to avoid
+ # allocating full [B, C, t_expanded, H, W] tensors at each resolution.
+ dec_out = []
+ conv_caches = {}
+
+ for chunk_start in range(0, t_expanded, chunk_size):
+ chunk_end = min(chunk_start + chunk_size, t_expanded)
+ x_chunk = x[:, :, chunk_start:chunk_end]
+ z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
+ z_spatial_cache = {}
+
+ for i in remaining_blocks:
+ block = decoder.up_blocks[i]
+ cache_key = f"up_block_{i}"
+ hw_key = (x_chunk.shape[3], x_chunk.shape[4])
+ if hw_key not in z_spatial_cache:
+ if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
+ z_spatial_cache[hw_key] = z_t_chunk
+ else:
+ z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
+ x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
+ conv_caches[cache_key] = new_cache
+
+ hw_key = (x_chunk.shape[3], x_chunk.shape[4])
+ if hw_key not in z_spatial_cache:
+ z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
+ x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
+ conv_caches["norm_out"] = new_cache
+ x_chunk = decoder.conv_act(x_chunk)
+ x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
+ conv_caches["conv_out"] = new_cache
+
+ dec_out.append(x_chunk.cpu())
+ del z_spatial_cache
+
+ del x, z_time_expanded
+ return torch.cat(dec_out, dim=2).to(device)
diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py
index 6f2ba41ef..3fb87b4a3 100644
--- a/comfy/ldm/lightricks/av_model.py
+++ b/comfy/ldm/lightricks/av_model.py
@@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit
+import comfy.model_prefetch
class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing."""
@@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
+ prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks
for i, block in enumerate(self.transformer_blocks):
+ comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace:
def block_wrap(args):
@@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep,
)
+ comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
+
return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index b193fe5e8..a68cb8439 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
+TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
+
if model_management.xformers_enabled():
import xformers
import xformers.ops
@@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
- scale = dim_head ** -0.5
+ if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
+ n_rep = q.shape[-3] // k.shape[-3]
+ k = k.repeat_interleave(n_rep, dim=-3)
+ v = v.repeat_interleave(n_rep, dim=-3)
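+        # e.g. 32 query heads sharing 8 KV heads -> n_rep = 4; each KV head is tiled
+        # so every query head has a matching key/value head.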
+
+ scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
+ if "scale" in kwargs:
+ # Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
+ query = query * (kwargs["scale"] * dim_head ** 0.5)
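+        # (q * scale * sqrt(d)) @ k / sqrt(d) == scale * (q @ k), so the requested
+        # scale is applied exactly once.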
+
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
- scale = dim_head ** -0.5
+ scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
+ # Pass through extra SDPA kwargs (scale, enable_gqa) if provided
+ # enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
+ sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
+ sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
+
if SDP_BATCH_LIMIT >= b:
- out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
- dropout_p=0.0, is_causal=False
+ dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out
diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py
index 233011dc9..aed5c149c 100644
--- a/comfy/ldm/modules/diffusionmodules/util.py
+++ b/comfy/ldm/modules/diffusionmodules/util.py
@@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
- # according the the formula provided in https://arxiv.org/abs/2010.02502
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py
index 12d3a01ab..23a972ac7 100644
--- a/comfy/ldm/sam3/detector.py
+++ b/comfy/ldm/sam3/detector.py
@@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
- new_det_thresh=0.5, max_objects=0, detect_interval=1):
+ new_det_thresh=0.5, max_objects=0, detect_interval=1,
+ target_device=None, target_dtype=None):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
@@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
- detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
+ detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
+ target_device=target_device, target_dtype=target_dtype)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
- return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
+ return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
+ target_device=target_device, target_dtype=target_dtype)
diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py
index 8f7481003..8456e90a6 100644
--- a/comfy/ldm/sam3/tracker.py
+++ b/comfy/ldm/sam3/tracker.py
@@ -200,8 +200,13 @@ def pack_masks(masks):
def unpack_masks(packed):
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
- shifts = torch.arange(8, device=packed.device)
- return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
+ bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
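+    # LSB-first: a packed byte 0b00000101 unpacks to [1, 0, 1, 0, 0, 0, 0, 0].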
+ return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
+
+
+def _prep_frame(images, idx, device, dt, size):
+ """Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
+ return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
def _compute_backbone(backbone_fn, frame, frame_idx=None):
@@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
# SAM3: drop last FPN level
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
- def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
+ def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
+ target_device=None, target_dtype=None):
"""Track one object, computing backbone per frame to save VRAM."""
N = images.shape[0]
- device, dt = images.device, images.dtype
+ device = target_device if target_device is not None else images.device
+ dt = target_dtype if target_dtype is not None else images.dtype
+ size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
- backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
+ backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
mask_input = None
if frame_idx == 0:
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
@@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
- def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
+ def track_video(self, backbone_fn, images, initial_masks, pbar=None,
+ target_device=None, target_dtype=None, **kwargs):
"""Track one or more objects across video frames.
Args:
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
- images: [N, 3, 1008, 1008] video frames
+ images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
pbar: optional progress bar
@@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
per_object = []
for obj_idx in range(N_obj):
obj_masks = self._track_single_object(
- backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
+ backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
+ target_device=target_device, target_dtype=target_dtype)
per_object.append(obj_masks)
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
@@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
return []
+    INTERNAL_MAX_OBJECTS = 64  # Hard ceiling on accumulated tracks; max_objects=0 (unlimited) or any larger value is clamped to this.
+
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1,
- backbone_obj=None, pbar=None):
+ backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
- N, device, dt = images.shape[0], images.device, images.dtype
+ if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
+ max_objects = self.INTERNAL_MAX_OBJECTS
+ N = images.shape[0]
+ device = target_device if target_device is not None else images.device
+ dt = target_dtype if target_dtype is not None else images.dtype
+ size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
idev = comfy.model_management.intermediate_device()
@@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
prefetch = True
except RuntimeError:
pass
- cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
+ cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
@@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
backbone_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(backbone_stream):
next_bb = self._compute_backbone_frame(
- backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
+ backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
det_masks = torch.empty(0, device=device)
@@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
- images[frame_idx:frame_idx + 1], trunk_out)
+ _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = [1.0] * mux_state.total_valid_entries
if keep_alive is not None:
@@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
output_dict, N, mux_state, backbone_obj,
- images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
+ _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
if keep_alive is not None:
@@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
- cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
+ cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
continue
else:
N_obj = mux_state.total_valid_entries
@@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
- cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
+ cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
if not all_masks or all(m is None for m in all_masks):
return {"packed_masks": None, "n_frames": N, "scores": []}
diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py
new file mode 100644
index 000000000..d72f53602
--- /dev/null
+++ b/comfy/ldm/wan/ar_model.py
@@ -0,0 +1,276 @@
+"""
+CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
+autoregressive (frame-by-frame) video generation via Causal Forcing.
+
+Weight-compatible with the standard WanModel -- same layer names, same shapes.
+The difference is purely in the forward pass: this model processes one temporal
+block at a time and maintains a KV cache across blocks.
+
+Reference: https://github.com/thu-ml/Causal-Forcing
+"""
+
+import torch
+import torch.nn as nn
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.math import apply_rope1
+from comfy.ldm.wan.model import (
+ sinusoidal_embedding_1d,
+ repeat_e,
+ WanModel,
+ WanAttentionBlock,
+)
+import comfy.ldm.common_dit
+import comfy.model_management
+
+
+class CausalWanSelfAttention(nn.Module):
+ """Self-attention with KV cache support for autoregressive inference."""
+
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
+ eps=1e-6, operation_settings={}):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ ops = operation_settings.get("operations")
+ device = operation_settings.get("device")
+ dtype = operation_settings.get("dtype")
+
+ self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
+ self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+ self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
+
+ def forward(self, x, freqs, kv_cache=None, transformer_options={}):
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
+ k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
+ v = self.v(x).view(b, s, n, d)
+
+ if kv_cache is None:
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ k.view(b, s, n * d),
+ v.view(b, s, n * d),
+ heads=self.num_heads,
+ transformer_options=transformer_options,
+ )
+ else:
+ end = kv_cache["end"]
+ new_end = end + s
+
+ # Roped K and plain V go into cache
+ kv_cache["k"][:, end:new_end] = k
+ kv_cache["v"][:, end:new_end] = v
+ kv_cache["end"] = new_end
+
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ kv_cache["k"][:, :new_end].view(b, new_end, n * d),
+ kv_cache["v"][:, :new_end].view(b, new_end, n * d),
+ heads=self.num_heads,
+ transformer_options=transformer_options,
+ )
+
+ x = self.o(x)
+ return x
+
+
+class CausalWanAttentionBlock(WanAttentionBlock):
+ """Transformer block with KV-cached self-attention and cross-attention caching."""
+
+ def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
+ window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
+ eps=1e-6, operation_settings={}):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps,
+ operation_settings=operation_settings)
+ self.self_attn = CausalWanSelfAttention(
+ dim, num_heads, window_size, qk_norm, eps,
+ operation_settings=operation_settings)
+
+ def forward(self, x, e, freqs, context, context_img_len=257,
+ kv_cache=None, crossattn_cache=None, transformer_options={}):
+ if e.ndim < 4:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ else:
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
+
+ # Self-attention with optional KV cache
+ x = x.contiguous()
+ y = self.self_attn(
+ torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
+ freqs, kv_cache=kv_cache, transformer_options=transformer_options)
+ x = torch.addcmul(x, y, repeat_e(e[2], x))
+ del y
+
+ # Cross-attention with optional caching
+ if crossattn_cache is not None and crossattn_cache.get("is_init"):
+ q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
+ x_ca = optimized_attention(
+ q, crossattn_cache["k"], crossattn_cache["v"],
+ heads=self.num_heads, transformer_options=transformer_options)
+ x = x + self.cross_attn.o(x_ca)
+ else:
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+ if crossattn_cache is not None:
+ crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
+ crossattn_cache["v"] = self.cross_attn.v(context)
+ crossattn_cache["is_init"] = True
+
+ # FFN
+ y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
+ x = torch.addcmul(x, y, repeat_e(e[5], x))
+ return x
+
+
+class CausalWanModel(WanModel):
+ """
+ Wan 2.1 diffusion backbone with causal KV-cache support.
+
+ Same weight structure as WanModel -- loads identical state dicts.
+ Adds forward_block() for frame-by-frame autoregressive inference.
+ """
+
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None):
+ super().__init__(
+ model_type=model_type, patch_size=patch_size, text_len=text_len,
+ in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
+ text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
+ num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
+ cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
+ wan_attn_block_class=CausalWanAttentionBlock,
+ device=device, dtype=dtype, operations=operations)
+
+ def forward_block(self, x, timestep, context, start_frame,
+ kv_caches, crossattn_caches, clip_fea=None):
+ """
+ Forward one temporal block for autoregressive inference.
+
+ Args:
+ x: [B, C, block_frames, H, W] input latent for the current block
+ timestep: [B, block_frames] per-frame timesteps
+ context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
+ start_frame: temporal frame index for RoPE offset
+ kv_caches: list of per-layer KV cache dicts
+ crossattn_caches: list of per-layer cross-attention cache dicts
+ clip_fea: optional CLIP features for I2V
+
+ Returns:
+ flow_pred: [B, C_out, block_frames, H, W] flow prediction
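+
+        Illustrative call sequence (a sketch only; the actual autoregressive driver
+        lives outside this file):
+
+            kv_caches = self.init_kv_caches(b, max_seq_len, device, dtype)
+            ca_caches = self.init_crossattn_caches(b, device, dtype)
+            for latent_block, start in temporal_blocks:
+                pred = self.forward_block(latent_block, timesteps, context,
+                                          start, kv_caches, ca_caches)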
+ """
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+ bs, c, t, h, w = x.shape
+
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # Per-frame time embedding
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
+ e = e.reshape(timestep.shape[0], -1, e.shape[-1])
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
+
+ # Text embedding (reuses crossattn_cache after first block)
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None and self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea)
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ # RoPE for current block's temporal position
+ freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
+
+ # Transformer blocks
+ for i, block in enumerate(self.blocks):
+ x = block(x, e=e0, freqs=freqs, context=context,
+ context_img_len=context_img_len,
+ kv_cache=kv_caches[i],
+ crossattn_cache=crossattn_caches[i])
+
+ # Head
+ x = self.head(x, e)
+
+ # Unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x[:, :, :t, :h, :w]
+
+ def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
+ """Create fresh KV caches for all layers."""
+ caches = []
+ for _ in range(self.num_layers):
+ caches.append({
+ "k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+ "v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
+ "end": 0,
+ })
+ return caches
+
+ def init_crossattn_caches(self, batch_size, device, dtype):
+ """Create fresh cross-attention caches for all layers."""
+ caches = []
+ for _ in range(self.num_layers):
+ caches.append({"is_init": False})
+ return caches
+
+ def reset_kv_caches(self, kv_caches):
+ """Reset KV caches to empty (reuse allocated memory)."""
+ for cache in kv_caches:
+ cache["end"] = 0
+
+ def reset_crossattn_caches(self, crossattn_caches):
+ """Reset cross-attention caches."""
+ for cache in crossattn_caches:
+ cache["is_init"] = False
+
+ @property
+ def head_dim(self):
+ return self.dim // self.num_heads
+
+ def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
+ ar_state = transformer_options.get("ar_state")
+ if ar_state is not None:
+ bs = x.shape[0]
+ block_frames = x.shape[2]
+ t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
+ return self.forward_block(
+ x=x, timestep=t_per_frame, context=context,
+ start_frame=ar_state["start_frame"],
+ kv_caches=ar_state["kv_caches"],
+ crossattn_caches=ar_state["crossattn_caches"],
+ clip_fea=clip_fea,
+ )
+
+ return super().forward(x, timestep, context, clip_fea=clip_fea,
+ time_dim_concat=time_dim_concat,
+ transformer_options=transformer_options, **kwargs)
diff --git a/comfy/lora.py b/comfy/lora.py
index 63ee85323..db8f16bcb 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -17,6 +17,7 @@
"""
from __future__ import annotations
+import comfy.memory_management
import comfy.utils
import comfy.model_management
import comfy.model_base
@@ -342,6 +343,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
+ if isinstance(model, comfy.model_base.ErnieImage):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["transformer.{}".format(key_lora)] = k
+
return key_map
@@ -467,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight
return weight
+
+def prefetch_prepared_value(value, allocate_buffer, stream):
+ if isinstance(value, torch.Tensor):
+ dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
+ comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+ return comfy.memory_management.interpret_gathered_like([value], dest)[0]
+ elif isinstance(value, weight_adapter.WeightAdapterBase):
+ return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+ elif isinstance(value, tuple):
+ return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+ elif isinstance(value, list):
+ return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+
+ return value
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 787ea1145..57a1e44d2 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
+import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@@ -52,6 +53,7 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
+import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
@@ -81,6 +83,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
+ V_PREDICTION_DDPM = 12
def model_sampling(model_config, model_type):
@@ -115,6 +118,8 @@ def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
+ elif model_type == ModelType.V_PREDICTION_DDPM:
+ c = comfy.model_sampling.V_PREDICTION_DDPM
class ModelSampling(s, c):
pass
@@ -210,6 +215,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
+ transformer_options = transformer_options.copy()
+ transformer_options["prefetch_dynamic_vbars"] = (
+ self.current_patcher is not None and self.current_patcher.is_dynamic()
+ )
+
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
@@ -1356,6 +1366,13 @@ class WAN21(BaseModel):
return out
+class WAN21_CausalAR(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device,
+ unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
+ self.image_to_video = False
+
+
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
@@ -1979,3 +1996,59 @@ class ErnieImage(BaseModel):
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
+
+class CogVideoX(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
+ self.image_to_video = image_to_video
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
+ extra_channels = self.diffusion_model.in_channels - noise.shape[1]
+ if extra_channels == 0:
+ return None
+
+ image = kwargs.get("concat_latent_image", None)
+ device = kwargs["device"]
+
+ if image is None:
+ shape = list(noise.shape)
+ shape[1] = extra_channels
+ return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+
+ latent_dim = self.latent_format.latent_channels
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+
+ if noise.ndim == 5 and image.ndim == 5:
+ if image.shape[-3] < noise.shape[-3]:
+ image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
+ elif image.shape[-3] > noise.shape[-3]:
+ image = image[:, :, :noise.shape[-3]]
+
+ for i in range(0, image.shape[1], latent_dim):
+ image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+
+ if image.shape[1] > extra_channels:
+ image = image[:, :extra_channels]
+ elif image.shape[1] < extra_channels:
+ repeats = extra_channels // image.shape[1]
+ remainder = extra_channels % image.shape[1]
+ parts = [image] * repeats
+ if remainder > 0:
+ parts.append(image[:, :remainder])
+ image = torch.cat(parts, dim=1)
+
+ return image
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
+ if self.diffusion_model.ofs_proj_dim is not None:
+ ofs = kwargs.get("ofs", None)
+ if ofs is None:
+ noise = kwargs.get("noise", None)
+ ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
+ out['ofs'] = comfy.conds.CONDRegular(ofs)
+ return out
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 724a241bf..d9b67dcdf 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
+ if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
+ dit_config = {}
+ dit_config["image_model"] = "cogvideox"
+
+ # Extract config from weight shapes
+ norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
+ time_embed_dim = norm1_weight.shape[1]
+ dim = norm1_weight.shape[0] // 6
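+        # norm1.linear projects time_embed_dim -> 6 * dim (six AdaLN modulation chunks)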
+
+ dit_config["num_attention_heads"] = dim // 64
+ dit_config["attention_head_dim"] = 64
+ dit_config["time_embed_dim"] = time_embed_dim
+ dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
+
+ # Detect in_channels from patch_embed
+ patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
+ if patch_proj_key in state_dict_keys:
+ w = state_dict[patch_proj_key]
+ if w.ndim == 4:
+ # Conv2d: [out, in, kh, kw] — CogVideoX 1.0
+ dit_config["in_channels"] = w.shape[1]
+ dit_config["patch_size"] = w.shape[2]
+ elif w.ndim == 2:
+ # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
+ dit_config["patch_size"] = 2
+ dit_config["patch_size_t"] = 2
+ dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
+
+ text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
+ if text_proj_key in state_dict_keys:
+ dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
+
+ # Detect OFS embedding
+ ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
+ if ofs_key in state_dict_keys:
+ dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
+
+ # Detect positional embedding type
+ pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
+ if pos_key in state_dict_keys:
+ dit_config["use_learned_positional_embeddings"] = True
+ dit_config["use_rotary_positional_embeddings"] = False
+ else:
+ dit_config["use_learned_positional_embeddings"] = False
+ dit_config["use_rotary_positional_embeddings"] = True
+
+ return dit_config
+
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 3b39d6080..21738a4c7 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management
import comfy.utils
import comfy.quant_ops
+import comfy_aimdo.vram_buffer
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
@@ -112,10 +113,6 @@ if args.directml is not None:
# torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
-try:
- import intel_extension_for_pytorch as ipex # noqa: F401
-except:
- pass
try:
_ = torch.xpu.device_count()
@@ -583,9 +580,6 @@ class LoadedModel:
real_model = self.model.model
- if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
- with torch.no_grad():
- real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
@@ -663,6 +657,7 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc()
+ comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = []
can_unload = []
unloaded_models = []
@@ -726,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
- models_temp = set()
+ # Order-preserving dedup. A plain set() would randomize iteration order across runs
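+    # (dict keys preserve insertion order since Python 3.7, so a dict works as an ordered set)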
+ models_temp = {}
for m in models:
- models_temp.add(m)
+ models_temp[m] = None
for mm in m.model_patches_models():
- models_temp.add(mm)
+ models_temp[mm] = None
- models = models_temp
+ models = list(models_temp)
+ models.reverse()
models_to_load = []
@@ -1181,6 +1178,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
+STREAM_AIMDO_CAST_BUFFERS = {}
+LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+
+DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT
@@ -1214,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer
+def get_aimdo_cast_buffer(offload_stream, device):
+ cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
+ if cast_buffer is None:
+ cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
+ STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
+
+ return cast_buffer
def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT
+ global LARGEST_AIMDO_CASTED_WEIGHT
+
LARGEST_CASTED_WEIGHT = (None, 0)
- for offload_stream in STREAM_CAST_BUFFERS:
- offload_stream.synchronize()
+ LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+ for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
+ if offload_stream is not None:
+ offload_stream.synchronize()
synchronize()
+
STREAM_CAST_BUFFERS.clear()
+ STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@@ -1580,10 +1594,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- if torch_version_numeric < (2, 3):
- return True
- else:
- return torch.xpu.get_device_properties(device).has_fp16
+ return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu():
return True
@@ -1649,10 +1660,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
if is_intel_xpu():
- if torch_version_numeric < (2, 3):
- return True
- else:
- return torch.xpu.is_bf16_supported()
+ return torch.xpu.is_bf16_supported()
if is_ascend_npu():
return True
@@ -1783,6 +1791,7 @@ def soft_empty_cache(force=False):
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()
elif is_intel_xpu():
+ torch.xpu.synchronize()
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index ee56f8523..33bdedfb1 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -26,11 +26,13 @@ import uuid
from typing import Callable, Optional
import torch
+import tqdm
import comfy.float
import comfy.hooks
import comfy.lora
import comfy.model_management
+import comfy.ops
import comfy.patcher_extension
import comfy.utils
from comfy.comfy_types import UnetWrapperFunction
@@ -120,9 +122,20 @@ class LowVramPatch:
self.patches = patches
self.convert_func = convert_func # TODO: remove
self.set_func = set_func
+ self.prepared_patches = None
+
+ def prepare(self, allocate_buffer, stream):
+ self.prepared_patches = [
+ (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+ for patch in self.patches[self.key]
+ ]
+
+ def clear_prepared(self):
+ self.prepared_patches = None
def __call__(self, weight):
- return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+ patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
+ return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
@@ -856,7 +869,9 @@ class ModelPatcher:
if m.comfy_patched_weights == True:
continue
- for param in params:
+ for param, param_value in params.items():
+ if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
+ comfy.ops.disable_weight_init._zero_init_parameter(m, param)
key = key_param_name_to_key(n, param)
self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to)
@@ -1637,7 +1652,11 @@ class ModelPatcherDynamic(ModelPatcher):
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
- logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
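+ # Demote repeated identical prepare messages to DEBUG while a tqdm progress bar is active, to avoid log spam inside sampling loops.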
+ log_key = (self.patches_uuid, allocated_size, num_patches, len(self.backup), self.model.model_loaded_weight_memory)
+ in_loop = bool(getattr(tqdm.tqdm, "_instances", None))
+ level = logging.DEBUG if in_loop and getattr(self, "_last_prepare_log_key", None) == log_key else logging.INFO
+ self._last_prepare_log_key = log_key
+ logging.log(level, f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py
new file mode 100644
index 000000000..72e11dec6
--- /dev/null
+++ b/comfy/model_prefetch.py
@@ -0,0 +1,66 @@
+import comfy_aimdo.model_vbar
+import comfy.model_management
+import comfy.ops
+
+PREFETCH_QUEUES = []
+
+def cleanup_prefetched_modules(comfy_modules):
+ for s in comfy_modules:
+ prefetch = getattr(s, "_prefetch", None)
+ if prefetch is None:
+ continue
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ lowvram_fn.clear_prepared()
+ if prefetch["signature"] is not None:
+ comfy_aimdo.model_vbar.vbar_unpin(s._v)
+ delattr(s, "_prefetch")
+
+def cleanup_prefetch_queues():
+ global PREFETCH_QUEUES
+
+ for queue in PREFETCH_QUEUES:
+ for entry in queue:
+ if entry is None or not isinstance(entry, tuple):
+ continue
+ _, prefetch_state = entry
+ comfy_modules = prefetch_state[1]
+ if comfy_modules is not None:
+ cleanup_prefetched_modules(comfy_modules)
+ PREFETCH_QUEUES = []
+
+def prefetch_queue_pop(queue, device, module):
+ if queue is None:
+ return
+
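+ # Retire the previously staged entry: the offload stream waits on the compute stream before its buffers are cleaned up, then the next group is staged.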
+ consumed = queue.pop(0)
+ if consumed is not None:
+ offload_stream, prefetch_state = consumed
+ if offload_stream is not None:
+ offload_stream.wait_stream(comfy.model_management.current_stream(device))
+ _, comfy_modules = prefetch_state
+ if comfy_modules is not None:
+ cleanup_prefetched_modules(comfy_modules)
+
+ prefetch = queue[0]
+ if prefetch is not None:
+ comfy_modules = []
+ for s in prefetch.modules():
+ if hasattr(s, "_v"):
+ comfy_modules.append(s)
+
+ offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
+ comfy.model_management.sync_stream(device, offload_stream)
+ queue[0] = (offload_stream, (prefetch, comfy_modules))
+
+def make_prefetch_queue(queue, device, transformer_options):
+ if (not transformer_options.get("prefetch_dynamic_vbars", False)
+ or comfy.model_management.NUM_STREAMS == 0
+ or comfy.model_management.is_device_cpu(device)
+ or not comfy.model_management.device_supports_non_blocking(device)):
+ return None
+
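+ # Bracket the queue with sentinels: the leading None is consumed by the first prefetch_queue_pop call, the trailing None means nothing is left to stage after the last entry.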
+ queue = [None] + queue + [None]
+ PREFETCH_QUEUES.append(queue)
+ return queue
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 13860e6a2..cf2b5db5f 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
+class V_PREDICTION_DDPM:
+ """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
+ x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
+ = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
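+ (using alpha = 1 / (sigma^2 + 1), so sqrt(alpha) = 1 / sqrt(sigma^2 + 1) and sqrt(1 - alpha) = sigma / sqrt(sigma^2 + 1))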
+ """
+ def calculate_input(self, sigma, noise):
+ return noise
+
+ def calculate_denoised(self, sigma, model_output, model_input):
+ sigma = reshape_sigma(sigma, model_output.ndim)
+ return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
+
+ def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+ sigma = reshape_sigma(sigma, noise.ndim)
+ if max_denoise:
+ noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
+ else:
+ noise = noise * sigma
+ noise += latent_image
+ return noise
+
+ def inverse_noise_scaling(self, sigma, latent):
+ return latent
+
class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
diff --git a/comfy/ops.py b/comfy/ops.py
index db5099b5e..966561b9e 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -79,37 +79,68 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
+def materialize_meta_param(s, param_keys):
+ for param_key in param_keys:
+ param = getattr(s, param_key, None)
+ if param is not None and getattr(param, "is_meta", False):
+ setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
- #vbar doesn't support CPU weights, but some custom nodes have weird paths
- #that might switch the layer to the CPU and expect it to work. We have to take
- #a clone conservatively as we are mmapped and some SFT files are packed misaligned
- #If you are a custom node author reading this, please move your layer to the GPU
- #or declare your ModelPatcher as CPU in the first place.
- if comfy.model_management.is_device_cpu(device):
- weight = s.weight.to(dtype=dtype, copy=True)
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- bias = None
- if s.bias is not None:
- bias = s.bias.to(dtype=bias_dtype, copy=True)
- return weight, bias, (None, None, None)
+# FIXME: add n=1 cache hit fast path
+def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
- xfer_dest = None
+ cast_buffer = None
+ cast_buffer_offset = 0
+
+ def ensure_offload_stream(module, required_size, check_largest):
+ nonlocal offload_stream
+ nonlocal cast_buffer
+
+ if offload_stream is None:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ if offload_stream is None or not check_largest or len(comfy_modules) != 1:
+ return
+
+ current_size = 0 if cast_buffer is None else cast_buffer.size()
+ if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ cast_buffer = None
+ if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
+ comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
+
+ def get_cast_buffer(buffer_size):
+ nonlocal offload_stream
+ nonlocal cast_buffer
+ nonlocal cast_buffer_offset
+
+ if buffer_size == 0:
+ return None
+
+ if offload_stream is None:
+ return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
+
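+ # Sub-allocate from the shared per-stream AIMDO buffer; the running offset packs successive requests back to back.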
+ cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
+ buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
+ cast_buffer_offset += buffer_size
+ return buffer
+
+ for s in comfy_modules:
+ signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
+ resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+ prefetch = {
+ "signature": signature,
+ "resident": resident,
+ }
- signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
- resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
- if signature is not None:
if resident:
- weight = s._v_weight
- bias = s._v_bias
- else:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+ s._prefetch = prefetch
+ continue
- if not resident:
+ materialize_meta_param(s, ["weight", "bias"])
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
+ needs_cast = False
xfer_source = [ s.weight, s.bias ]
@@ -121,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
+ needs_cast = True
cast_dest = xfer_dest
- if cast_dest is None:
- cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
- offload_stream = comfy.model_management.get_offload_stream(device)
- if xfer_dest is None and offload_stream is not None:
- xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
- if xfer_dest is None:
- offload_stream = comfy.model_management.get_offload_stream(device)
- xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+ ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None:
- xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
- offload_stream = None
+ xfer_dest = get_cast_buffer(dest_size)
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@@ -149,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
- comfy.model_management.sync_stream(device, offload_stream)
- if cast_dest is not None:
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ ensure_offload_stream(s, cast_buffer_offset, False)
+ lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+
+ prefetch["xfer_dest"] = xfer_dest
+ prefetch["cast_dest"] = cast_dest
+ prefetch["cast_geometry"] = cast_geometry
+ prefetch["needs_cast"] = needs_cast
+ s._prefetch = prefetch
+
+ return offload_stream
+
+
+def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
+
+ prefetch = getattr(s, "_prefetch", None)
+
+ if prefetch["resident"]:
+ weight = s._v_weight
+ bias = s._v_bias
+ else:
+ xfer_dest = prefetch["xfer_dest"]
+ if prefetch["needs_cast"]:
+ cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
- comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+ comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
- params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
- if signature is not None:
+ if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
- s._v_signature=signature
+ s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
+ if x is None:
+ return None
+
orig = x
def to_dequant(tensor, dtype):
@@ -197,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
- update_weight = signature is not None
+ update_weight = prefetch["signature"] is not None
+ weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
+ if bias is not None:
+ bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
- weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
- if s.bias is not None:
- bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
+ if prefetch["signature"] is not None:
+ prefetch["resident"] = True
- #FIXME: weird offload return protocol
- return weight, bias, (offload_stream, device if signature is not None else None, None)
+ return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@@ -222,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
+ def format_return(result, offloadable):
+ weight, bias, offload_stream = result
+ return (weight, bias, offload_stream) if offloadable else (weight, bias)
+
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
- return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
+
+ #vbar doesn't support CPU weights, but some custom nodes have weird paths
+ #that might switch the layer to the CPU and expect it to work. We have to take
+ #a clone conservatively as we are mmapped and some SFT files are packed misaligned
+ #If you are a custom node author reading this, please move your layer to the GPU
+ #or declare your ModelPatcher as CPU in the first place.
+ if comfy.model_management.is_device_cpu(device):
+ materialize_meta_param(s, ["weight", "bias"])
+ weight = s.weight.to(dtype=dtype, copy=True)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
+ return format_return((weight, bias, (None, None, None)), offloadable)
+
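+ # If a prefetch queue already staged this module (s._prefetch is set), skip the synchronous cast below and reuse the staged transfer.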
+ prefetched = hasattr(s, "_prefetch")
+ offload_stream = None
+ offload_device = None
+ if not prefetched:
+ offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
+
+ if not prefetched:
+ if getattr(s, "_prefetch")["signature"] is not None:
+ offload_device = device
+ for param_key in ("weight", "bias"):
+ lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+ if lowvram_fn is not None:
+ lowvram_fn.clear_prepared()
+ delattr(s, "_prefetch")
+ return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
+
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -272,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
- if offloadable:
- return weight, bias, (offload_stream, weight_a, bias_a)
- else:
- #Legacy function signature
- return weight, bias
+ return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):
@@ -306,6 +390,12 @@ class CastWeightBiasOp:
bias_function = []
class disable_weight_init:
+ @staticmethod
+ def _zero_init_parameter(module, name):
+ param = getattr(module, name)
+ device = None if getattr(param, "is_meta", False) else param.device
+ setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
+
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
@@ -472,6 +562,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
+ class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+ running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
+ running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
+ x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
+ uncast_bias_weight(self, weight, bias, offload_stream)
+ return x
+
+ def forward(self, *args, **kwargs):
+ run_every_op()
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@@ -659,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
+ class BatchNorm2d(disable_weight_init.BatchNorm2d):
+ comfy_cast_weights = True
+
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True
@@ -1205,6 +1317,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
+ class Embedding(manual_cast.Embedding):
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys, error_msgs):
+ weight_key = f"{prefix}weight"
+ layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
+ if layer_conf is not None:
+ layer_conf = json.loads(layer_conf.numpy().tobytes())
+
+ # Only fp8 makes sense for embeddings (per-row dequant via index select).
+ # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
+ quant_format = layer_conf.get("format", None) if layer_conf is not None else None
+ if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
+ self.quant_format = quant_format
+ qconfig = QUANT_ALGOS[quant_format]
+ layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
+ weight = state_dict.pop(weight_key)
+ manually_loaded_keys = [weight_key]
+
+ scale_key = f"{prefix}weight_scale"
+ scale = state_dict.pop(scale_key, None)
+ if scale is not None:
+ scale = scale.float()
+ manually_loaded_keys.append(scale_key)
+
+ params = layout_cls.Params(
+ scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
+ orig_dtype=MixedPrecisionOps._compute_dtype,
+ orig_shape=(self.num_embeddings, self.embedding_dim),
+ )
+ self.weight = torch.nn.Parameter(
+ QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
+ requires_grad=False)
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+ for k in manually_loaded_keys:
+ if k in missing_keys:
+ missing_keys.remove(k)
+ else:
+ if layer_conf is not None:
+ state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ def state_dict(self, *args, destination=None, prefix="", **kwargs):
+ if destination is not None:
+ sd = destination
+ else:
+ sd = {}
+
+ if not hasattr(self, 'weight') or self.weight is None:
+ return sd
+
+ if isinstance(self.weight, QuantizedTensor):
+ sd_out = self.weight.state_dict("{}weight".format(prefix))
+ for k in sd_out:
+ sd[k] = sd_out[k]
+
+ quant_conf = {"format": self.quant_format}
+ sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
+ else:
+ sd["{}weight".format(prefix)] = self.weight
+ return sd
+
+ def forward_comfy_cast_weights(self, input, out_dtype=None):
+ weight = self.weight
+
+ # Optimized path: lookup in fp8, dequantize only the selected rows.
+ if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
+ qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
+ if isinstance(qdata, QuantizedTensor):
+ scale = qdata._params.scale
+ qdata = qdata._qdata
+ else:
+ scale = None
+
+ x = torch.nn.functional.embedding(
+ input, qdata, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+ uncast_bias_weight(self, qdata, None, offload_stream)
+ target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
+ x = x.to(dtype=target_dtype)
+ if scale is not None and scale != 1.0:
+ x = x * scale.to(dtype=target_dtype)
+ return x
+
+ # Fallback for non-quantized or weight_function (LoRA) case
+ return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
+
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py
index 6f142282d..6d3ba367a 100644
--- a/comfy/pinned_memory.py
+++ b/comfy/pinned_memory.py
@@ -2,7 +2,6 @@ import comfy.model_management
import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
-import psutil
from comfy.cli_args import args
@@ -12,11 +11,6 @@ def get_pin(module):
def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return
- #FIXME: This is a RAM cache trigger event
- ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
- #we split the difference and assume half the RAM cache headroom is for us
- if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
- comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index 442447995..cae5f1180 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -1,6 +1,8 @@
import torch
import logging
+from comfy.cli_args import args
+
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
@@ -27,7 +29,15 @@ try:
"other kitchen CUDA ops (svdquant W4A4, fp8, mxfp8, rope) remain active.",
".".join(map(str, cuda_version)))
- ck.registry.disable("triton")
+ if args.enable_triton_backend:
+ try:
+ import triton
+ logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
+ except ImportError as e:
+ logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
+ ck.registry.disable("triton")
+ else:
+ ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index ab7cf14fa..e54be98d6 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
+# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index bbba09e26..3782fd2d5 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models")
- control_nets = set(cnets)
+ # Order-preserving dedup. A plain set() would randomize iteration order across runs
+ control_nets = list(dict.fromkeys(cnets))
inference_memory = 0
control_models = []
diff --git a/comfy/sd.py b/comfy/sd.py
index 736fe35de..749bdd710 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -18,6 +18,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
+import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
@@ -64,6 +65,8 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
+import comfy.text_encoders.gemma4
+import comfy.text_encoders.cogvideo
import comfy.model_patcher
import comfy.lora
@@ -478,7 +481,10 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
- self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
+ if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
+ self.latent_channels = metadata["tae_latent_channels"]
+ else:
+ self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
@@ -652,6 +658,17 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+ elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+ self.upscale_index_formula = (4, 8, 8)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+ self.downscale_index_formula = (4, 8, 8)
+ self.latent_dim = 3
+ self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
+ self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
+ self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+ self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -1208,6 +1225,7 @@ class CLIPType(Enum):
NEWBIE = 24
FLUX2 = 25
LONGCAT_IMAGE = 26
+ COGVIDEOX = 27
@@ -1256,6 +1274,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
+ GEMMA_4_E4B = 29
+ GEMMA_4_E2B = 30
+ GEMMA_4_31B = 31
def detect_te_model(sd):
@@ -1281,6 +1302,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
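+ # Gemma 4 variants are distinguished by the deepest layer index present: 60 layers -> 31B, 42 -> E4B, 35 -> E2B.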
+ if 'model.layers.59.self_attn.q_norm.weight' in sd:
+ return TEModel.GEMMA_4_31B
+ if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
+ return TEModel.GEMMA_4_E4B
+ if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
+ return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@@ -1403,6 +1430,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
+ elif clip_type == CLIPType.COGVIDEOX:
+ clip_target.clip = comfy.text_encoders.cogvideo.cogvideo_te(**t5xxl_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.cogvideo.CogVideoXTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -1420,6 +1450,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+ elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
+ variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
+ TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
+ TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
+ clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
+ clip_target.tokenizer = variant.tokenizer
+ tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 8886f32d5..6a9613602 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -27,6 +27,7 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
+import comfy.text_encoders.cogvideo
from . import supported_models_base
from . import latent_formats
@@ -1166,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
+class WAN21_CausalAR_T2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "t2v",
+ "causal_ar": True,
+ }
+
+ sampling_settings = {
+ "shift": 5.0,
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+ self.unet_config.pop("causal_ar", None)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.WAN21_CausalAR(self, device=device)
+
+
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1832,6 +1852,156 @@ class SAM31(SAM3):
unet_config = {"image_model": "SAM31"}
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
+class CogVideoX_T2V(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "cogvideox",
+ }
-models += [SVD_img2vid]
+ sampling_settings = {
+ "linear_start": 0.00085,
+ "linear_end": 0.012,
+ "beta_schedule": "linear",
+ "zsnr": True,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.CogVideoX
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def __init__(self, unet_config):
+ # 2b-class (dim=1920, heads=30) uses scale_factor=1.15258426.
+ # 5b-class (dim=3072, heads=48) — incl. CogVideoX-5b, 1.5-5B, and
+ # Fun-V1.5 inpainting — uses scale_factor=0.7 per vae/config.json.
+ if unet_config.get("num_attention_heads", 0) >= 48:
+ self.latent_format = latent_formats.CogVideoX1_5
+ super().__init__(unet_config)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
+ if self.unet_config.get("patch_size_t") is not None:
+ self.unet_config.setdefault("sample_height", 96)
+ self.unet_config.setdefault("sample_width", 170)
+ self.unet_config.setdefault("sample_frames", 81)
+ out = model_base.CogVideoX(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
+
+class CogVideoX_I2V(CogVideoX_T2V):
+ unet_config = {
+ "image_model": "cogvideox",
+ "in_channels": 32,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ if self.unet_config.get("patch_size_t") is not None:
+ self.unet_config.setdefault("sample_height", 96)
+ self.unet_config.setdefault("sample_width", 170)
+ self.unet_config.setdefault("sample_frames", 81)
+ out = model_base.CogVideoX(self, image_to_video=True, device=device)
+ return out
+
+class CogVideoX_Inpaint(CogVideoX_T2V):
+ unet_config = {
+ "image_model": "cogvideox",
+ "in_channels": 48,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ if self.unet_config.get("patch_size_t") is not None:
+ self.unet_config.setdefault("sample_height", 96)
+ self.unet_config.setdefault("sample_width", 170)
+ self.unet_config.setdefault("sample_frames", 81)
+ out = model_base.CogVideoX(self, image_to_video=True, device=device)
+ return out
+
+
+models = [
+ LotusD,
+ Stable_Zero123,
+ SD15_instructpix2pix,
+ SD15,
+ SD20,
+ SD21UnclipL,
+ SD21UnclipH,
+ SDXL_instructpix2pix,
+ SDXLRefiner,
+ SDXL,
+ SSD1B,
+ KOALA_700M,
+ KOALA_1B,
+ Segmind_Vega,
+ SD_X4Upscaler,
+ Stable_Cascade_C,
+ Stable_Cascade_B,
+ SV3D_u,
+ SV3D_p,
+ SD3,
+ StableAudio,
+ AuraFlow,
+ PixArtAlpha,
+ PixArtSigma,
+ HunyuanDiT,
+ HunyuanDiT1,
+ FluxInpaint,
+ Flux,
+ LongCatImage,
+ FluxSchnell,
+ GenmoMochi,
+ LTXV,
+ LTXAV,
+ HunyuanVideo15_SR_Distilled,
+ HunyuanVideo15,
+ HunyuanImage21Refiner,
+ HunyuanImage21,
+ HunyuanVideoSkyreelsI2V,
+ HunyuanVideoI2V,
+ HunyuanVideo,
+ CosmosT2V,
+ CosmosI2V,
+ CosmosT2IPredict2,
+ CosmosI2VPredict2,
+ ZImagePixelSpace,
+ ZImage,
+ Lumina2,
+ WAN22_T2V,
+ WAN21_CausalAR_T2V,
+ WAN21_T2V,
+ WAN21_I2V,
+ WAN21_FunControl2V,
+ WAN21_Vace,
+ WAN21_Camera,
+ WAN22_Camera,
+ WAN22_S2V,
+ WAN21_HuMo,
+ WAN22_Animate,
+ WAN21_FlowRVS,
+ WAN21_SCAIL,
+ Hunyuan3Dv2mini,
+ Hunyuan3Dv2,
+ Hunyuan3Dv2_1,
+ HiDream,
+ Chroma,
+ ChromaRadiance,
+ ACEStep,
+ ACEStep15,
+ Omnigen2,
+ QwenImage,
+ Flux2,
+ Kandinsky5Image,
+ Kandinsky5,
+ Anima,
+ RT_DETR_v4,
+ ErnieImage,
+ SAM3,
+ SAM31,
+ CogVideoX_Inpaint,
+ CogVideoX_I2V,
+ CogVideoX_T2V,
+ SVD_img2vid,
+]
diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py
index 6c06ce19d..696013200 100644
--- a/comfy/taesd/taehv.py
+++ b/comfy/taesd/taehv.py
@@ -7,6 +7,7 @@ from tqdm.auto import tqdm
from collections import namedtuple, deque
import comfy.ops
+import comfy.model_management
operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
@@ -47,11 +48,14 @@ class TGrow(nn.Module):
x = self.conv(x)
return x.reshape(-1, C, H, W)
-def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
+def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
+ patch_size=1, decode=False):
B, T, C, H, W = x.shape
if parallel:
x = x.reshape(B*T, C, H, W)
+ if not decode and patch_size > 1:
+ x = F.pixel_unshuffle(x, patch_size)
# parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock):
@@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
x = b(x, mem)
else:
x = b(x)
- BT, C, H, W = x.shape
- T = BT // B
- x = x.view(B, T, C, H, W)
+ if decode and patch_size > 1:
+ x = F.pixel_shuffle(x, patch_size)
+ x = x.view(B, x.shape[0] // B, *x.shape[1:])
+ x = x.to(output_device)
else:
out = []
- work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
+ # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
+ # Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
+ work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model)
while work_queue:
xt, i = work_queue.popleft()
if i == 0:
progress_bar.update(1)
+ if not decode and patch_size > 1:
+ xt = F.pixel_unshuffle(xt, patch_size)
if i == len(model):
- out.append(xt)
+ if decode and patch_size > 1:
+ xt = F.pixel_shuffle(xt, patch_size)
+ out.append(xt.to(output_device))
del xt
else:
b = model[i]
@@ -165,24 +176,20 @@ class TAEHV(nn.Module):
def encode(self, x, **kwargs):
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
- if self.patch_size > 1:
- B, T, C, H, W = x.shape
- x = x.reshape(B * T, C, H, W)
- x = F.pixel_unshuffle(x, self.patch_size)
- x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1)
- x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
+ x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
+ patch_size=self.patch_size).movedim(2, 1)
return self.process_out(x)
def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
- x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
- if self.patch_size > 1:
- x = F.pixel_shuffle(x, self.patch_size)
+ x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
+ output_device=comfy.model_management.intermediate_device(),
+ patch_size=self.patch_size, decode=True)
return x[:, self.frames_to_trim:].movedim(2, 1)
diff --git a/comfy/taesd/taesd.py b/comfy/taesd/taesd.py
index ce36f1a84..05d370209 100644
--- a/comfy/taesd/taesd.py
+++ b/comfy/taesd/taesd.py
@@ -17,32 +17,79 @@ class Clamp(nn.Module):
return torch.tanh(x / 3) * 3
class Block(nn.Module):
- def __init__(self, n_in, n_out):
+ def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU()
- def forward(self, x):
+ if not use_midblock_gn:
+ self.pool = None
+ return
+ n_gn = n_in * 4
+ self.pool = nn.Sequential(
+ comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
+ comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
+ nn.ReLU(inplace=True),
+ comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.pool is not None:
+ x = x + self.pool(x)
return self.fuse(self.conv(x) + self.skip(x))
-def Encoder(latent_channels=4):
- return nn.Sequential(
- conv(3, 64), Block(64, 64),
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
- conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
- conv(64, latent_channels),
- )
+class Encoder(nn.Sequential):
+ def __init__(self, latent_channels: int = 4, use_gn: bool = False):
+ super().__init__(
+ conv(3, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
+ conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
+ conv(64, latent_channels),
+ )
+class Decoder(nn.Sequential):
+ def __init__(self, latent_channels: int = 4, use_gn: bool = False):
+ super().__init__(
+ Clamp(), conv(latent_channels, 64), nn.ReLU(),
+ Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
+ Block(64, 64), conv(64, 3),
+ )
+
+class DecoderFlux2(Decoder):
+ def __init__(self, latent_channels: int = 128, use_gn: bool = True):
+ if latent_channels != 128 or not use_gn:
+ raise ValueError("Unexpected parameters for Flux2 TAE module")
+ super().__init__(latent_channels=32, use_gn=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ B, C, H, W = x.shape
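+ # Depth-to-space: the 128-channel input packs 32 channels per 2x2 spatial patch; unpack to 32 channels at twice the resolution before the shared decoder.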
+ x = (
+ x
+ .reshape(B, 32, 2, 2, H, W)
+ .permute(0, 1, 4, 2, 5, 3)
+ .reshape(B, 32, H * 2, W * 2)
+ )
+ return super().forward(x)
+
+class EncoderFlux2(Encoder):
+ def __init__(self, latent_channels: int = 128, use_gn: bool = True):
+ if latent_channels != 128 or not use_gn:
+ raise ValueError("Unexpected parameters for Flux2 TAE module")
+ super().__init__(latent_channels=32, use_gn=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ result = super().forward(x)
+ B, C, H, W = result.shape
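+ # Space-to-depth: fold each 2x2 patch of the 32-channel encoder output into 128 channels at half the spatial resolution.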
+ return (
+ result
+ .reshape(B, C, H // 2, 2, W // 2, 2)
+ .permute(0, 1, 3, 5, 2, 4)
+ .reshape(B, 128, H // 2, W // 2)
+ )
-def Decoder(latent_channels=4):
- return nn.Sequential(
- Clamp(), conv(latent_channels, 64), nn.ReLU(),
- Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
- Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
- Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
- Block(64, 64), conv(64, 3),
- )
class TAESD(nn.Module):
latent_magnitude = 3
@@ -51,8 +98,15 @@ class TAESD(nn.Module):
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
"""Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__()
- self.taesd_encoder = Encoder(latent_channels=latent_channels)
- self.taesd_decoder = Decoder(latent_channels=latent_channels)
+ if latent_channels == 128:
+ encoder_class = EncoderFlux2
+ decoder_class = DecoderFlux2
+ else:
+ encoder_class = Encoder
+ decoder_class = Decoder
+ self.taesd_encoder = encoder_class(latent_channels=latent_channels)
+ self.taesd_decoder = decoder_class(latent_channels=latent_channels)
+
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None:
@@ -61,19 +115,19 @@ class TAESD(nn.Module):
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod
- def scale_latents(x):
+ def scale_latents(x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
@staticmethod
- def unscale_latents(x):
+ def unscale_latents(x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
- def decode(self, x):
+ def decode(self, x: torch.Tensor) -> torch.Tensor:
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2)
return x_sample
- def encode(self, x):
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift
diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py
new file mode 100644
index 000000000..b97310709
--- /dev/null
+++ b/comfy/text_encoders/cogvideo.py
@@ -0,0 +1,48 @@
+import comfy.text_encoders.sd3_clip
+from comfy import sd1_clip
+
+
+class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
+ """Inner T5 tokenizer for CogVideoX.
+
+ CogVideoX was trained with T5 embeddings padded to 226 tokens (not 77 like SD3).
+ Used both directly by supported_models.CogVideoX_T2V.clip_target (paired with
+ the raw T5XXLModel) and by the CogVideoXTokenizer outer wrapper below.
+ """
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
+
+
+class CogVideoXTokenizer(sd1_clip.SD1Tokenizer):
+ """Outer tokenizer wrapper for CLIPLoader (type="cogvideox")."""
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
+ clip_name="t5xxl", tokenizer=CogVideoXT5Tokenizer)
+
+
+class CogVideoXT5XXL(sd1_clip.SD1ClipModel):
+ """Outer T5XXL model wrapper for CLIPLoader (type="cogvideox").
+
+ Wraps the raw T5XXL model in the SD1ClipModel interface so that CLIP.__init__
+ (which reads self.dtypes) works correctly. The inner model is the standard
+ sd3_clip.T5XXLModel (no attention_mask change needed for CogVideoX).
+ """
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="t5xxl",
+ clip_model=comfy.text_encoders.sd3_clip.T5XXLModel,
+ model_options=model_options)
+
+
+def cogvideo_te(dtype_t5=None, t5_quantization_metadata=None):
+ """Factory that returns a CogVideoXT5XXL class configured with the detected
+ T5 dtype and optional quantization metadata, for use in load_text_encoder_state_dicts.
+ """
+ class CogVideoXTEModel_(CogVideoXT5XXL):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if t5_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
+ if dtype_t5 is not None:
+ dtype = dtype_t5
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return CogVideoXTEModel_
diff --git a/comfy/text_encoders/gemma4.py b/comfy/text_encoders/gemma4.py
new file mode 100644
index 000000000..f050061ed
--- /dev/null
+++ b/comfy/text_encoders/gemma4.py
@@ -0,0 +1,1298 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from dataclasses import dataclass
+import math
+
+from comfy import sd1_clip
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.rmsnorm import rms_norm
+from comfy.text_encoders.llama import RMSNorm, MLP, BaseLlama, BaseGenerate, _make_scaled_embedding
+
+
+# Intentional minor divergences from the transformers reference implementation:
+# - Embedding sqrt(hidden_size) scale applied as a Python scalar (full precision) instead of dtype-matched buffer tensor.
+# - RMSNorm uses torch's fused F.rms_norm: very slight numerical differences, but considerably faster.
+# - Input image and audio resizing/resampling slightly different numerically
+
+
+GEMMA4_VISION_CONFIG = {"hidden_size": 768, "image_size": 896, "intermediate_size": 3072, "num_attention_heads": 12, "num_hidden_layers": 16, "patch_size": 16, "head_dim": 64, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
+GEMMA4_VISION_31B_CONFIG = {"hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 16, "head_dim": 72, "rms_norm_eps": 1e-6, "position_embedding_size": 10240, "pooling_kernel_size": 3}
+GEMMA4_AUDIO_CONFIG = {"hidden_size": 1024, "num_hidden_layers": 12, "num_attention_heads": 8, "intermediate_size": 4096, "conv_kernel_size": 5, "attention_chunk_size": 12, "attention_context_left": 13, "attention_context_right": 0, "attention_logit_cap": 50.0, "output_proj_dims": 1536, "rms_norm_eps": 1e-6, "residual_weight": 0.5}
+
+@dataclass
+class Gemma4Config:
+ vocab_size: int = 262144
+ hidden_size: int = 2560
+ intermediate_size: int = 10240
+ num_hidden_layers: int = 42
+ num_attention_heads: int = 8
+ num_key_value_heads: int = 2
+ max_position_embeddings: int = 131072
+ rms_norm_eps: float = 1e-6
+ rope_theta = [1000000.0, 10000.0]
+ transformer_type: str = "gemma4"
+ head_dim = 256
+ global_head_dim = 512
+ rms_norm_add = False
+ mlp_activation = "gelu_pytorch_tanh"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ sliding_attention = [512, 512, 512, 512, 512, False]
+ rope_scale = None
+ partial_rotary_factor: float = 0.25
+ final_norm: bool = True
+ lm_head: bool = False
+ final_logit_softcapping: float = 30.0
+ hidden_size_per_layer_input: int = 256
+ num_kv_shared_layers: int = 18
+ use_double_wide_mlp: bool = False
+ stop_tokens = [1, 50, 106]
+ vision_config = GEMMA4_VISION_CONFIG
+ audio_config = GEMMA4_AUDIO_CONFIG
+ mm_tokens_per_image = 280
+
+@dataclass
+class Gemma4_E2B_Config(Gemma4Config):
+ hidden_size: int = 1536
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 35
+ num_key_value_heads: int = 1
+ sliding_attention = [512, 512, 512, 512, False]
+ num_kv_shared_layers: int = 20
+ use_double_wide_mlp: bool = True
+
+@dataclass
+class Gemma4_31B_Config(Gemma4Config):
+ hidden_size: int = 5376
+ intermediate_size: int = 21504
+ num_hidden_layers: int = 60
+ num_attention_heads: int = 32
+ num_key_value_heads: int = 16
+ sliding_attention = [1024, 1024, 1024, 1024, 1024, False]
+ hidden_size_per_layer_input: int = 0
+ num_kv_shared_layers: int = 0
+ audio_config = None
+ vision_config = GEMMA4_VISION_31B_CONFIG
+
+
+# Unfused RoPE, since an addcmul_-based RoPE diverges from the reference code.
+def _apply_rotary_pos_emb(x, freqs_cis):
+ cos, sin = freqs_cis[0], freqs_cis[1]
+ half = x.shape[-1] // 2
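+ # Equivalent to x * cos + rotate_half(x) * sin, with rotate_half(x) = cat(-x[..., half:], x[..., :half]).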
+ out = x * cos
+ out[..., :half] -= x[..., half:] * sin[..., :half]
+ out[..., half:] += x[..., :half] * sin[..., half:]
+ return out
+
+class Gemma4Attention(nn.Module):
+ def __init__(self, config, head_dim, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.num_heads = config.num_attention_heads
+ self.num_kv_heads = config.num_key_value_heads
+ self.hidden_size = config.hidden_size
+ self.head_dim = head_dim
+ self.inner_size = self.num_heads * head_dim
+
+ self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * head_dim, bias=config.qkv_bias, device=device, dtype=dtype)
+ self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+
+ self.q_norm = None
+ self.k_norm = None
+ if config.q_norm == "gemma3":
+ self.q_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ if config.k_norm == "gemma3":
+ self.k_norm = RMSNorm(head_dim, eps=config.rms_norm_eps, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask=None,
+ freqs_cis=None,
+ past_key_value=None,
+ sliding_window=None,
+ shared_kv=None,
+ ):
+ batch_size, seq_length, _ = hidden_states.shape
+
+ xq = self.q_proj(hidden_states)
+ xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
+ if self.q_norm is not None:
+ xq = self.q_norm(xq)
+
+ if shared_kv is not None:
+ xk, xv = shared_kv
+ # Apply RoPE to Q only (K already has RoPE from source layer)
+ xq = _apply_rotary_pos_emb(xq, freqs_cis)
+ present_key_value = None
+ shareable_kv = None
+ else:
+ xk = self.k_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+ xv = self.v_proj(hidden_states).view(batch_size, seq_length, self.num_kv_heads, self.head_dim)
+ if self.k_norm is not None:
+ xk = self.k_norm(xk)
+ xv = rms_norm(xv)
+ xk = xk.transpose(1, 2)
+ xv = xv.transpose(1, 2)
+ xq = _apply_rotary_pos_emb(xq, freqs_cis)
+ xk = _apply_rotary_pos_emb(xk, freqs_cis)
+
+ present_key_value = None
+ if past_key_value is not None:
+ cumulative_len = 0
+ if len(past_key_value) > 0:
+ past_key, past_value, cumulative_len = past_key_value
+ xk = torch.cat((past_key, xk), dim=2)
+ xv = torch.cat((past_value, xv), dim=2)
+ new_cumulative = cumulative_len + seq_length
+ if sliding_window is not None and xk.shape[2] > sliding_window - 1:
+ cache_k = xk[:, :, -(sliding_window - 1):]
+ cache_v = xv[:, :, -(sliding_window - 1):]
+ else:
+ cache_k = xk
+ cache_v = xv
+ present_key_value = (cache_k, cache_v, new_cumulative)
+
+ # KV for sharing: full xk/xv that SDPA sees (not evicted cache)
+ shareable_kv = (xk, xv)
+
+ # GQA: pass unexpanded KV with enable_gqa when no sliding mask,
+ # expand heads when a sliding mask is present.
+ # The GQA handling has to happen inside SDPA itself to match the reference code; expanding KV before scaling causes numerical differences.
+ expand_kv = (self.num_heads != self.num_kv_heads and
+ sliding_window is not None and
+ xk.shape[2] >= sliding_window)
+ if expand_kv:
+ xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+ xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
+ gqa_kwargs = {} if expand_kv else ({"enable_gqa": True} if self.num_heads != self.num_kv_heads else {})
+ output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0, **gqa_kwargs)
+
+ return self.o_proj(output), present_key_value, shareable_kv
+
+
+class TransformerBlockGemma4(nn.Module):
+ def __init__(self, config, index, device=None, dtype=None, ops=None):
+ super().__init__()
+ if config.sliding_attention is not None:
+ self.sliding_attention = config.sliding_attention[index % len(config.sliding_attention)]
+ else:
+ self.sliding_attention = False
+
+ head_dim = config.head_dim if self.sliding_attention else config.global_head_dim
+
+ self.self_attn = Gemma4Attention(config, head_dim=head_dim, device=device, dtype=dtype, ops=ops)
+
+ num_kv_shared = config.num_kv_shared_layers
+ first_kv_shared = config.num_hidden_layers - num_kv_shared
+ mlp_size = config.intermediate_size * 2 if config.use_double_wide_mlp and index >= first_kv_shared else None
+ self.mlp = MLP(config, device=device, dtype=dtype, ops=ops, intermediate_size=mlp_size)
+
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ if self.hidden_size_per_layer_input:
+ self.per_layer_input_gate = ops.Linear(config.hidden_size, self.hidden_size_per_layer_input, bias=False, device=device, dtype=dtype)
+ self.per_layer_projection = ops.Linear(self.hidden_size_per_layer_input, config.hidden_size, bias=False, device=device, dtype=dtype)
+ self.post_per_layer_input_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.register_buffer("layer_scalar", torch.ones(1, device=device, dtype=dtype))
+ else:
+ self.layer_scalar = None
+
+ def forward(self, x, attention_mask=None, freqs_cis=None, past_key_value=None, per_layer_input=None, shared_kv=None):
+ sliding_window = None
+ if self.sliding_attention:
+ sliding_window = self.sliding_attention
+ # For prefill > sliding window, add sliding window restriction to the causal mask.
+ if x.shape[1] > self.sliding_attention:
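+ # tril_(-sliding_attention) marks keys at least sliding_attention positions behind the query; combined with the causal mask, each query attends to at most the last sliding_attention tokens.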
+ sw_mask = torch.zeros(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
+ sw_mask.masked_fill_(torch.ones_like(sw_mask, dtype=torch.bool).tril_(-self.sliding_attention), torch.finfo(x.dtype).min)
+ attention_mask = attention_mask + sw_mask if attention_mask is not None else sw_mask
+ freqs_cis = freqs_cis[1]
+ else:
+ freqs_cis = freqs_cis[0]
+
+ residual = x
+ x = self.input_layernorm(x)
+ x, present_key_value, shareable_kv = self.self_attn(
+ hidden_states=x, attention_mask=attention_mask, freqs_cis=freqs_cis,
+ past_key_value=past_key_value, sliding_window=sliding_window, shared_kv=shared_kv,
+ )
+ x = self.post_attention_layernorm(x)
+ x = residual + x
+
+ residual = x
+ x = self.pre_feedforward_layernorm(x)
+ x = self.mlp(x)
+ x = self.post_feedforward_layernorm(x)
+ x = residual + x
+
+ if self.hidden_size_per_layer_input and per_layer_input is not None:
+ residual = x
+ x = self.per_layer_input_gate(x)
+ x = torch.nn.functional.gelu(x, approximate="tanh")
+ x = x * per_layer_input
+ x = self.per_layer_projection(x)
+ x = self.post_per_layer_input_norm(x)
+ x = residual + x
+
+ if self.layer_scalar is not None:
+ x = x * self.layer_scalar
+
+ return x, present_key_value, shareable_kv
+
+
+class Gemma4Transformer(nn.Module):
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.config = config
+
+ self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
+
+ self.layers = nn.ModuleList([
+ TransformerBlockGemma4(config, index=i, device=device, dtype=dtype, ops=ops)
+ for i in range(config.num_hidden_layers)
+ ])
+
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) if config.final_norm else None
+
+ # Precompute RoPE inv_freq on CPU to match reference code's exact value
+ rope_angles_global = int(config.partial_rotary_factor * config.global_head_dim // 2)
+ nope_global = config.global_head_dim // 2 - rope_angles_global
+ global_inv = 1.0 / (config.rope_theta[0] ** (torch.arange(0, 2 * rope_angles_global, 2).float() / config.global_head_dim))
+ if nope_global > 0:
+ global_inv = torch.cat([global_inv, torch.zeros(nope_global)])
+ self.register_buffer("_global_inv_freq", global_inv, persistent=False)
+
+ sliding_inv = 1.0 / (config.rope_theta[1] ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim))
+ self.register_buffer("_sliding_inv_freq", sliding_inv, persistent=False)
+
+ # Per-layer input mechanism
+ self.hidden_size_per_layer_input = config.hidden_size_per_layer_input
+ if self.hidden_size_per_layer_input:
+ self.embed_tokens_per_layer = _make_scaled_embedding(ops, config.vocab_size, config.num_hidden_layers * self.hidden_size_per_layer_input, self.hidden_size_per_layer_input ** 0.5, device, dtype)
+ self.per_layer_model_projection = ops.Linear(
+ config.hidden_size, config.num_hidden_layers * self.hidden_size_per_layer_input,
+ bias=False, device=device, dtype=dtype)
+ self.per_layer_projection_norm = RMSNorm(
+ self.hidden_size_per_layer_input, eps=config.rms_norm_eps,
+ device=device, dtype=dtype)
+
+ def get_past_len(self, past_key_values):
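+ # Cache entries are (key, value, cumulative_len) tuples; layers without a populated cache store empty tuples, so return the first cumulative length found.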
+ for kv in past_key_values:
+ if len(kv) >= 3:
+ return kv[2]
+ return 0
+
+ def _freqs_from_inv(self, inv_freq, position_ids, device, dtype):
+ """Compute cos/sin from stored inv_freq"""
+ inv_exp = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(device)
+ pos_exp = position_ids[:, None, :].float()
+ freqs = (inv_exp @ pos_exp).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ return emb.cos().unsqueeze(1).to(dtype), emb.sin().unsqueeze(1).to(dtype)
+
+ def compute_freqs_cis(self, position_ids, device, dtype=None):
+ global_freqs = self._freqs_from_inv(self._global_inv_freq, position_ids, device, dtype)
+ sliding_freqs = self._freqs_from_inv(self._sliding_inv_freq, position_ids, device, dtype)
+ return [global_freqs, sliding_freqs]
+
+ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None,
+ final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=None,
+ past_key_values=None, input_ids=None):
+ if embeds is not None:
+ x = embeds
+ else:
+ x = self.embed_tokens(x, out_dtype=dtype)
+
+ seq_len = x.shape[1]
+ past_len = 0
+ if past_key_values is not None and len(past_key_values) > 0:
+ past_len = self.get_past_len(past_key_values)
+
+ if position_ids is None:
+ position_ids = torch.arange(past_len, past_len + seq_len, device=x.device).unsqueeze(0)
+
+ freqs_cis = self.compute_freqs_cis(position_ids, x.device, dtype=x.dtype)
+
+ mask = None
+ min_val = torch.finfo(x.dtype).min
+ if attention_mask is not None:
+ mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1])
+ mask = mask.masked_fill(mask.to(torch.bool), min_val)
+
+ if seq_len > 1:
+ causal_mask = torch.zeros(past_len + seq_len, past_len + seq_len, dtype=x.dtype, device=x.device)
+ causal_mask.masked_fill_(torch.ones_like(causal_mask, dtype=torch.bool).triu_(1), min_val)
+ mask = mask + causal_mask if mask is not None else causal_mask
+
+ # Per-layer inputs
+ per_layer_inputs = None
+ if self.hidden_size_per_layer_input:
+ num_layers = self.config.num_hidden_layers
+ hpl = self.hidden_size_per_layer_input
+ per_layer_proj = self.per_layer_model_projection(x) * (1.0 / (self.config.hidden_size ** 0.5))
+ per_layer_proj = self.per_layer_projection_norm(per_layer_proj.reshape(*x.shape[:-1], num_layers, hpl))
+ if input_ids is not None and input_ids.shape[1] == x.shape[1]:
+ per_layer_emb = self.embed_tokens_per_layer(input_ids).reshape(*input_ids.shape, num_layers, hpl)
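+ # Combine the projected hidden states with the per-layer token embeddings, scaled by 1/sqrt(2).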
+ per_layer_inputs = (per_layer_proj + per_layer_emb) * (0.5 ** 0.5)
+ else:
+ per_layer_inputs = per_layer_proj
+
+ # KV sharing: later layers reuse KV from the last non-shared sliding/global layer
+ num_kv_shared = self.config.num_kv_shared_layers
+ first_kv_shared = self.config.num_hidden_layers - num_kv_shared if num_kv_shared > 0 else self.config.num_hidden_layers
+ shared_sliding_kv = None # KV from last non-shared sliding layer
+ shared_global_kv = None # KV from last non-shared global layer
+
+ intermediate = None
+ next_key_values = []
+ for i, layer in enumerate(self.layers):
+ past_kv = past_key_values[i] if past_key_values is not None and len(past_key_values) > 0 else None
+
+ layer_kwargs = {}
+ if per_layer_inputs is not None:
+ layer_kwargs['per_layer_input'] = per_layer_inputs[:, :, i, :]
+
+ is_sliding = hasattr(layer, 'sliding_attention') and layer.sliding_attention
+ if i >= first_kv_shared and num_kv_shared > 0:
+ shared = shared_sliding_kv if is_sliding else shared_global_kv
+ if shared is not None:
+ layer_kwargs['shared_kv'] = shared
+
+ x, current_kv, shareable_kv = layer(x=x, attention_mask=mask, freqs_cis=freqs_cis, past_key_value=past_kv, **layer_kwargs)
+
+ next_key_values.append(current_kv if current_kv is not None else ())
+
+ # Only track the last sliding/global before the sharing boundary
+ if i < first_kv_shared and shareable_kv is not None:
+ if is_sliding:
+ shared_sliding_kv = shareable_kv
+ else:
+ shared_global_kv = shareable_kv
+
+ if i == intermediate_output:
+ intermediate = x.clone()
+
+ if self.norm is not None:
+ x = self.norm(x)
+
+ if len(next_key_values) > 0:
+ return x, intermediate, next_key_values
+ return x, intermediate
+
+
+class Gemma4Base(BaseLlama, BaseGenerate, torch.nn.Module):
+ """Common base for all Gemma4 variants: text model + vision."""
+ def _init_model(self, config, dtype, device, operations):
+ self.num_layers = config.num_hidden_layers
+ self.model = Gemma4Transformer(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+ self.multi_modal_projector = Gemma4MultiModalProjector(config, dtype=dtype, device=device, ops=operations)
+ self.vision_model = Gemma4VisionEncoder(config.vision_config, dtype=dtype, device=device, ops=operations)
+
+ def logits(self, x):
+ logits = super().logits(x)
+ cap = self.model.config.final_logit_softcapping
+ if cap:
+ logits = cap * torch.tanh(logits / cap)
+ return logits
+
+ def init_kv_cache(self, batch, max_cache_len, device, execution_dtype):
+ past_key_values = []
+ for _ in range(self.model.config.num_hidden_layers):
+ past_key_values.append(())
+ return past_key_values
+
+ def preprocess_embed(self, embed, device):
+ if embed["type"] == "image":
+ image = embed.pop("data").movedim(-1, 1) # [B, H, W, C] -> [B, C, H, W]
+ max_soft_tokens = embed.get("max_soft_tokens", None)
+ vision_out = self.vision_model(image.to(device, dtype=torch.float32), max_soft_tokens=max_soft_tokens)
+ return self.multi_modal_projector(vision_out), None
+ return None, None
+
+
+class Gemma4AudioMixin:
+ """Adds audio support to a Gemma4 model."""
+ def _init_audio(self, config, dtype, device, operations):
+ self.audio_model = Gemma4AudioEncoder(config.audio_config, dtype=dtype, device=device, ops=operations)
+ self.audio_projector = Gemma4AudioProjector({"audio_output_proj_dims": config.audio_config["output_proj_dims"], "text_hidden_size": config.hidden_size, "rms_norm_eps": config.rms_norm_eps}, dtype=dtype, device=device, ops=operations)
+
+ def preprocess_embed(self, embed, device):
+ result, extra = super().preprocess_embed(embed, device)
+ if result is not None:
+ return result, extra
+ if embed["type"] == "audio":
+ audio = embed.pop("data").to(device, dtype=torch.float32)
+ audio_mask = embed.pop("mask", None)
+ if audio_mask is not None:
+ audio_mask = audio_mask.to(device)
+ audio_out = self.audio_model(audio, audio_mask=audio_mask)
+ return self.audio_projector(audio_out), None
+ return None, None
+
+
+# Vision Encoder
+
+def _compute_vision_2d_rope(head_dim, pixel_position_ids, theta=100.0, device=None):
+ """Compute 2D RoPE for vision: separate frequencies for x and y dimensions.
+
+ Args:
+ head_dim: dimension per head (e.g. 64)
+ pixel_position_ids: [batch, num_patches, 2] with (x, y) coords
+ theta: RoPE base frequency
+ Returns:
+ (cos, sin) each of shape [batch, num_patches, head_dim]
+ """
+ rotary_dim_per_axis = head_dim // 2
+ freq_indices = torch.arange(0, rotary_dim_per_axis, 2, device=device).float()
+ inv_freq = 1.0 / (theta ** (freq_indices / rotary_dim_per_axis))
+
+ all_cos, all_sin = [], []
+ for i in range(2): # x and y
+ dim_positions = pixel_position_ids[:, :, i].float() # [batch, num_patches]
+ freqs = torch.einsum('bi,j->bij', dim_positions, inv_freq.to(device)) # [batch, num_patches, rotary_dim/2]
+ emb = torch.cat([freqs, freqs], dim=-1) # [batch, num_patches, rotary_dim]
+ all_cos.append(emb.cos())
+ all_sin.append(emb.sin())
+
+ cos = torch.cat(all_cos, dim=-1).to(pixel_position_ids.device) # [batch, num_patches, head_dim]
+ sin = torch.cat(all_sin, dim=-1).to(pixel_position_ids.device)
+ return cos, sin
+
+
+def _apply_vision_2d_rope(x, freqs):
+ """Apply 2D RoPE (multidimensional) to vision query/key states.
+
+ Splits x and cos/sin into ndim=2 parts, applies 1D RoPE to each independently.
+
+ x: [batch, heads, seq, head_dim]
+ freqs: (cos, sin) each [batch, seq, head_dim]
+ """
+ cos = freqs[0].unsqueeze(1) # [batch, 1, seq, head_dim]
+ sin = freqs[1].unsqueeze(1)
+ half = x.shape[-1] // 2
+ a = _apply_rotary_pos_emb(x[..., :half], (cos[..., :half], sin[..., :half]))
+ b = _apply_rotary_pos_emb(x[..., half:], (cos[..., half:], sin[..., half:]))
+ return torch.cat([a, b], dim=-1)
+
+
+class ClippedLinear(nn.Module):
+ """Linear layer with activation clipping (from quantization-aware training).
+
+ Stores input_max/min and output_max/min as buffers loaded from checkpoint.
+ """
+ def __init__(self, in_features, out_features, bias=False, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.linear = ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
+ self.register_buffer('input_max', torch.tensor(float('inf'), device=device, dtype=dtype))
+ self.register_buffer('input_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
+ self.register_buffer('output_max', torch.tensor(float('inf'), device=device, dtype=dtype))
+ self.register_buffer('output_min', torch.tensor(float('-inf'), device=device, dtype=dtype))
+
+ @property
+ def weight(self):
+ return self.linear.weight
+
+ def forward(self, x):
+ x = x.clamp(min=self.input_min, max=self.input_max)
+ x = self.linear(x)
+ return x.clamp_(min=self.output_min, max=self.output_max)
+
+
+class Gemma4VisionMLP(nn.Module):
+ """SwiGLU MLP matching gate_proj/up_proj/down_proj structure."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ intermediate_size = config["intermediate_size"]
+ self.gate_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
+ self.up_proj = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
+ self.down_proj = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
+
+ def forward(self, x):
+ return self.down_proj(torch.nn.functional.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
+
+
+class Gemma4VisionAttention(nn.Module):
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.hidden_size = config["hidden_size"]
+ self.num_heads = config["num_attention_heads"]
+ self.head_dim = config.get("head_dim", self.hidden_size // self.num_heads)
+
+ self.q_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
+ self.k_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
+ self.v_proj = ClippedLinear(self.hidden_size, self.num_heads * self.head_dim, device=device, dtype=dtype, ops=ops)
+ self.o_proj = ClippedLinear(self.num_heads * self.head_dim, self.hidden_size, device=device, dtype=dtype, ops=ops)
+
+ self.q_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ self.k_norm = RMSNorm(self.head_dim, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+
+ def forward(self, x, freqs, attention_mask=None):
+ batch_size, seq_length, _ = x.shape
+
+ xq = self.q_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+ xk = self.k_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+ xv = self.v_proj(x).view(batch_size, seq_length, self.num_heads, self.head_dim)
+
+ xq = self.q_norm(xq).transpose(1, 2)
+ xk = self.k_norm(xk).transpose(1, 2)
+ xv = rms_norm(xv)
+
+ xq = _apply_vision_2d_rope(xq, freqs)
+ xk = _apply_vision_2d_rope(xk, freqs)
+
+ xv = xv.to(xq.dtype).transpose(1, 2)
+
+ output = optimized_attention_for_device(xq.device, mask=attention_mask is not None, small_input=True)(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True, scale=1.0)
+ return self.o_proj(output)
+
+
+class Gemma4VisionLayer(nn.Module):
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.self_attn = Gemma4VisionAttention(config, device=device, dtype=dtype, ops=ops)
+ self.mlp = Gemma4VisionMLP(config, device=device, dtype=dtype, ops=ops)
+ norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ hidden = config["hidden_size"]
+ self.input_layernorm = RMSNorm(hidden, **norm_kwargs)
+ self.post_attention_layernorm = RMSNorm(hidden, **norm_kwargs)
+ self.pre_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
+ self.post_feedforward_layernorm = RMSNorm(hidden, **norm_kwargs)
+
+ def forward(self, x, freqs, attention_mask=None):
+ residual = x
+ x = self.input_layernorm(x)
+ x = self.self_attn(x, freqs, attention_mask=attention_mask)
+ x = self.post_attention_layernorm(x)
+ x = residual + x
+
+ residual = x
+ x = self.pre_feedforward_layernorm(x)
+ x = self.mlp(x)
+ x = self.post_feedforward_layernorm(x)
+ x = residual + x
+ return x
+
+
+class Gemma4PatchEmbedder(nn.Module):
+ """Patch embedding with learned 2D position embeddings via one-hot lookup."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ patch_size = config["patch_size"]
+ self.patch_size = patch_size
+ self.position_embedding_size = config.get("position_embedding_size", 10240)
+
+ self.input_proj = ops.Linear(3 * patch_size * patch_size, hidden_size, bias=False, device=device, dtype=dtype)
+ self.position_embedding_table = nn.Parameter(
+ torch.empty(2, self.position_embedding_size, hidden_size, device=device, dtype=dtype)
+ )
+
+ def forward(self, patches, pixel_position_ids):
+ """
+ patches: [B, num_patches, 3*patch_size²] in [0,1] range (normalized to [-1,1] inside, matching HF)
+ pixel_position_ids: [B, num_patches, 2] with (x,y) positions, (-1,-1) for padding
+ """
+ hidden_states = self.input_proj((2.0 * (patches - 0.5)).to(self.input_proj.weight.dtype))
+
+ clamped_positions = pixel_position_ids.clamp(min=0)
+ pos_table = comfy.model_management.cast_to_device(self.position_embedding_table, hidden_states.device, hidden_states.dtype)
+ position_embeddings = pos_table[0][clamped_positions[..., 0]] + pos_table[1][clamped_positions[..., 1]]
+
+ # Zero out position embeddings for padding patches (matching HF)
+ padding_positions = (pixel_position_ids == -1).all(dim=-1)
+ position_embeddings = torch.where(padding_positions.unsqueeze(-1), 0.0, position_embeddings)
+
+ return hidden_states + position_embeddings
+
+
+class Gemma4VisionEncoderLayers(nn.Module):
+ """Wrapper to produce state dict keys as encoder.layers.X.*"""
+ def __init__(self, config, dtype=None, device=None, ops=None):
+ super().__init__()
+ self.layers = nn.ModuleList([
+ Gemma4VisionLayer(config, device=device, dtype=dtype, ops=ops)
+ for _ in range(config["num_hidden_layers"])
+ ])
+
+
+class Gemma4VisionEncoder(nn.Module):
+ def __init__(self, config, dtype=None, device=None, ops=None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config["hidden_size"]
+ self.head_dim = config.get("head_dim", config["hidden_size"] // config["num_attention_heads"])
+ self.patch_size = config["patch_size"]
+ self.pooling_kernel_size = config.get("pooling_kernel_size", 3)
+ self.root_hidden_size = self.hidden_size ** 0.5
+
+ self.patch_embedder = Gemma4PatchEmbedder(config, device=device, dtype=dtype, ops=ops)
+ self.encoder = Gemma4VisionEncoderLayers(config, dtype=dtype, device=device, ops=ops)
+
+ def forward(self, pixel_values, max_soft_tokens=None):
+ """
+ pixel_values: [B, C, H, W] in [0,1] range
+ max_soft_tokens: if provided, pad to max_soft_tokens * k² total patches
+ """
+ batch_size, _, height, width = pixel_values.shape
+ ps = self.patch_size
+ k = self.pooling_kernel_size
+ patches_h, patches_w = height // ps, width // ps
+ num_patches = patches_h * patches_w
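+ # Pad the patch sequence so that pooling k*k patches per token yields exactly output_length soft tokens.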
+ output_length = max_soft_tokens if max_soft_tokens is not None else num_patches // (k * k)
+ n_padding = output_length * k * k - num_patches
+
+ # Patchify and build position grid
+ patches = pixel_values.reshape(batch_size, -1, patches_h, ps, patches_w, ps)
+ patches = patches.permute(0, 2, 4, 3, 5, 1).reshape(batch_size, num_patches, -1)
+ grid_y, grid_x = torch.meshgrid(torch.arange(patches_h, device=pixel_values.device), torch.arange(patches_w, device=pixel_values.device), indexing='ij')
+ position_ids = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=-1).unsqueeze(0).expand(batch_size, -1, -1)
+
+ # Append zero-pixel padding with (-1,-1) positions
+ if n_padding > 0:
+ patches = torch.cat([patches, patches.new_zeros(batch_size, n_padding, patches.shape[-1])], dim=1)
+ position_ids = torch.cat([position_ids, position_ids.new_full((batch_size, n_padding, 2), -1)], dim=1)
+
+ padding = (position_ids == -1).all(dim=-1)
+
+ # Embed, encode, pool
+ x = self.patch_embedder(patches, position_ids)
+ freqs = _compute_vision_2d_rope(self.head_dim, position_ids, device=pixel_values.device)
+ freqs = tuple(t.to(x.dtype) for t in freqs)
+ if n_padding > 0:
+ mask = padding.unsqueeze(1).unsqueeze(2).expand(-1, 1, position_ids.shape[1], -1)
+ mask = torch.zeros_like(mask, dtype=x.dtype).masked_fill_(mask, torch.finfo(x.dtype).min)
+ else:
+ mask = None
+
+ for layer in self.encoder.layers:
+ x = layer(x, freqs, attention_mask=mask)
+
+ if n_padding > 0:
+ x = x.masked_fill(padding.unsqueeze(-1), 0.0)
+
+ # Average pool by spatial position
+ clamped = position_ids.clamp(min=0)
+ max_x = clamped[:, :, 0].max(dim=-1, keepdim=True)[0] + 1
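+ # Pooled token index = floor(x / k) + (max_x // k) * floor(y / k); the one-hot weights average the k*k patches mapped to each output token.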
+ ki = torch.div(clamped, k, rounding_mode="floor")
+ ki = ki[:, :, 0] + (max_x // k) * ki[:, :, 1]
+ weights = torch.nn.functional.one_hot(ki.long(), output_length).float() / (k * k)
+ x = (weights.transpose(1, 2) @ x.float()).to(x.dtype)
+
+ # Strip empty output tokens
+ valid_out = ~((weights == 0).all(dim=1))
+ if valid_out.any() and not valid_out.all():
+ x = x[:, valid_out[0]] if batch_size > 1 else x[valid_out].unsqueeze(0)
+
+ return x * self.root_hidden_size
+
+
+class Gemma4RMSNormProjector(nn.Module):
+ """Shared projector: parameterless RMSNorm → linear. Used for both vision and audio."""
+ def __init__(self, in_dim, out_dim, dtype=None, device=None, ops=None):
+ super().__init__()
+ self.embedding_projection = ops.Linear(in_dim, out_dim, bias=False, device=device, dtype=dtype)
+
+ def forward(self, x):
+ return self.embedding_projection(rms_norm(x))
+
+
+class Gemma4MultiModalProjector(Gemma4RMSNormProjector):
+ def __init__(self, config, dtype=None, device=None, ops=None):
+ super().__init__(config.vision_config["hidden_size"], config.hidden_size, dtype=dtype, device=device, ops=ops)
+
+
+# Audio Encoder
+
+class Gemma4AudioConvSubsampler(nn.Module):
+ """2D convolution subsampling for audio features"""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ eps = config["rms_norm_eps"]
+ self.layer0 = nn.ModuleDict({
+ 'conv': ops.Conv2d(1, 128, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
+ 'norm': ops.LayerNorm(128, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
+ })
+ self.layer1 = nn.ModuleDict({
+ 'conv': ops.Conv2d(128, 32, kernel_size=3, stride=2, padding=1, bias=False, device=device, dtype=dtype),
+ 'norm': ops.LayerNorm(32, eps=eps, elementwise_affine=True, bias=False, device=device, dtype=dtype),
+ })
+ # proj_input_dim = (128 // 4) * 32 = 1024
+ self.input_proj_linear = ops.Linear(1024, config["hidden_size"], bias=False, device=device, dtype=dtype)
+
+ def _conv_layer(self, x, layer, mask):
+ if mask is not None:
+ x = x * mask[:, None, :, None].to(x.device)
+ x = layer['conv'](x.to(layer['conv'].weight.dtype))
+ x = torch.relu(layer['norm'](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous())
+ if mask is not None:
+ mask = mask[:, ::2]
+ return x, mask
+
+ def forward(self, x, mask=None):
+ x = x.unsqueeze(1)
+ x, mask = self._conv_layer(x, self.layer0, mask)
+ x, mask = self._conv_layer(x, self.layer1, mask)
+ batch_size, _, seq_len, _ = x.shape
+ x = x.permute(0, 2, 3, 1).contiguous().reshape(batch_size, seq_len, -1)
+ return self.input_proj_linear(x), mask
+
+
+class Gemma4AudioFeedForward(nn.Module):
+ """Conformer feed-forward with residual scaling."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ intermediate_size = config.get("intermediate_size", hidden_size * 4)
+ self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ self.ffw_layer_1 = ClippedLinear(hidden_size, intermediate_size, device=device, dtype=dtype, ops=ops)
+ self.ffw_layer_2 = ClippedLinear(intermediate_size, hidden_size, device=device, dtype=dtype, ops=ops)
+ self.post_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ self.post_layer_scale = config.get("residual_weight", 0.5)
+
+ def forward(self, x):
+ residual = x
+ x = self.pre_layer_norm(x)
+ x = torch.nn.functional.silu(self.ffw_layer_1(x))
+ x = self.ffw_layer_2(x)
+ x = self.post_layer_norm(x)
+ x = x * self.post_layer_scale
+ return x + residual
+
+
+class Gemma4AudioRelPositionalEncoding(nn.Module):
+ """Sinusoidal relative positional encoding for audio attention."""
+ def __init__(self, config, device=None, dtype=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ context_left = config.get("attention_context_left", 13)
+ context_right = config.get("attention_context_right", 0)
+ self.chunk_size = config.get("attention_chunk_size", 12)
+ self.context_size = self.chunk_size + context_left - 1 + context_right
+
+ num_timescales = hidden_size // 2
+ log_inc = math.log(10000.0) / max(num_timescales - 1, 1)
+ inv_timescales = torch.exp(torch.arange(num_timescales) * -log_inc).to(dtype=dtype).unsqueeze(0).unsqueeze(0)
+ self.register_buffer("inv_timescales", inv_timescales, persistent=False)
+
+ def forward(self, hidden_states):
+ positions = torch.arange(self.chunk_size, -1, -1, device=hidden_states.device).unsqueeze(-1)
+ scaled = positions * self.inv_timescales.to(device=hidden_states.device)
+ return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1).to(dtype=hidden_states.dtype)
+
+
+class Gemma4AudioAttention(nn.Module):
+ """Chunked block attention with relative position bias and softcap."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.hidden_size = config["hidden_size"]
+ self.num_heads = config["num_attention_heads"]
+ self.head_dim = self.hidden_size // self.num_heads
+ self.chunk_size = config.get("attention_chunk_size", 12)
+ self.max_past_horizon = config.get("attention_context_left", 13) - 1
+ self.max_future_horizon = config.get("attention_context_right", 0)
+ self.context_size = self.chunk_size + self.max_past_horizon + self.max_future_horizon
+
+ self.q_scale = (self.head_dim ** -0.5) / math.log(2)
+ self.k_scale = math.log(1 + math.e) / math.log(2)
+ self.register_buffer("softcap", torch.tensor(config.get("attention_logit_cap", 50.0), dtype=dtype), persistent=False)
+
+ self.q_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
+ self.k_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
+ self.v_proj = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
+ self.post = ClippedLinear(self.hidden_size, self.hidden_size, device=device, dtype=dtype, ops=ops)
+ self.per_dim_scale = nn.Parameter(torch.empty(self.head_dim, device=device, dtype=dtype))
+ self.relative_k_proj = ops.Linear(self.hidden_size, self.hidden_size, bias=False, device=device, dtype=dtype)
+
+ def _convert_to_block(self, x):
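+ # Pad the sequence to a whole number of chunks and reshape to [B, num_blocks, chunk_size, H, D].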
+ B, S, H, D = x.shape
+ num_blocks = (S + self.chunk_size - 1) // self.chunk_size
+ pad = num_blocks * self.chunk_size - S
+ x = torch.nn.functional.pad(x, (0, 0, 0, 0, 0, pad))
+ return x.reshape(B, num_blocks, self.chunk_size, H, D).contiguous()
+
+ def _extract_block_context(self, x):
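+ # For each chunk, unfold an overlapping window of context_size (past horizon + chunk + future horizon) positions from the key/value sequence.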
+ x = torch.nn.functional.pad(x, (0, 0, 0, 0, self.max_past_horizon, self.max_future_horizon + self.chunk_size - 1))
+ x = x.unfold(1, self.context_size, self.chunk_size)
+ return torch.movedim(x, -1, 2).contiguous()
+
+ def _rel_shift(self, x):
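+ # Relative-shift trick: pad the last dim, flatten, truncate and reshape so each query position lines up with its relative offsets.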
+ B, H, NB, BS, PL = x.shape
+ CS = self.context_size
+ x = torch.nn.functional.pad(x, (0, CS + 1 - PL))
+ x = x.view(B, H, NB, BS * (CS + 1))
+ x = x[..., :BS * CS]
+ return x.view(B, H, NB, BS, CS)
+
+ def _build_blocked_mask(self, seq_len, num_blocks, device, audio_mask=None):
+ """Build 5D boolean blocked attention mask (True=attend, False=mask)"""
+ q = torch.arange(seq_len, device=device)
+ dist = q[:, None] - q[None, :]
+ mask = (dist >= 0) & (dist < self.max_past_horizon)
+ if self.max_future_horizon > 0:
+ mask = mask | ((dist < 0) & ((-dist) < self.max_future_horizon))
+ if audio_mask is not None:
+ mask = mask & audio_mask[0, None, :].bool()
+ m = mask[None, None]
+ # Reshape to blocked 5D matching reference code
+ p = num_blocks * self.chunk_size - seq_len
+ m = torch.nn.functional.pad(m, (0, p, 0, p), value=False)
+ m = m.reshape(1, 1, num_blocks, self.chunk_size, -1)
+ m = torch.nn.functional.pad(m, (self.max_past_horizon, self.max_future_horizon), value=False)
+ idx = (torch.arange(num_blocks, device=device) * self.chunk_size)[:, None] + torch.arange(self.context_size, device=device)[None, :]
+ return m.gather(-1, idx[None, None, :, None, :].expand(1, 1, -1, self.chunk_size, -1))
+
+ def forward(self, x, position_embeddings=None, attn_mask=None):
+ B, S, _ = x.shape
+
+ q = self.q_proj(x).float().view(B, S, self.num_heads, self.head_dim)
+ k = self.k_proj(x).float().view(B, S, self.num_heads, self.head_dim)
+ v = self.v_proj(x).float().view(B, S, self.num_heads, self.head_dim)
+
+ q = q * self.q_scale * torch.nn.functional.softplus(self.per_dim_scale)
+ k = k * self.k_scale
+
+ q_blocks = self._convert_to_block(q)
+ k_context = self._extract_block_context(k)
+ v_context = self._extract_block_context(v)
+ num_blocks = q_blocks.shape[1]
+
+ rel_k = self.relative_k_proj(position_embeddings).view(-1, self.num_heads, self.head_dim).to(q.dtype)
+
+ queries = q_blocks.permute(0, 3, 1, 2, 4) # [B, H, NB, CS, D]
+ matrix_ac = queries @ k_context.permute(0, 3, 1, 4, 2)
+
+ queries_flat = queries.reshape(B, self.num_heads, -1, self.head_dim)
+ matrix_bd = queries_flat @ rel_k.permute(1, 2, 0)
+ matrix_bd = matrix_bd.reshape(B, self.num_heads, num_blocks, self.chunk_size, -1)
+ matrix_bd = self._rel_shift(matrix_bd)
+
+ attn_weights = matrix_ac + matrix_bd
+ attn_weights = torch.tanh(attn_weights / self.softcap) * self.softcap
+
+ # Mask out invalid positions in chunk context (matching reference's masked_fill approach)
+ if attn_mask is None:
+ attn_mask = self._build_blocked_mask(S, num_blocks, x.device)
+ attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), -1e9)
+
+ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(v.dtype)
+ out = attn_weights @ v_context.permute(0, 3, 1, 2, 4)
+ out = out.permute(0, 2, 3, 1, 4).reshape(B, num_blocks * self.chunk_size, -1)
+ out = out[:, :S].contiguous()
+ return self.post(out.to(self.post.linear.weight.dtype))
+
+
+class Gemma4AudioLConv1d(nn.Module):
+ """Lightweight convolution with standard GLU."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ hidden_size = config["hidden_size"]
+ conv_kernel_size = config.get("conv_kernel_size", 5)
+ self.pre_layer_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ self.linear_start = ClippedLinear(hidden_size, hidden_size * 2, device=device, dtype=dtype, ops=ops)
+ # Causal conv: left-pad only
+ self.depthwise_conv1d = ops.Conv1d(hidden_size, hidden_size, kernel_size=conv_kernel_size, padding=0, groups=hidden_size, bias=False, device=device, dtype=dtype)
+ self.conv_left_pad = conv_kernel_size - 1 # causal: pad left by kernel-1
+ self.conv_norm = RMSNorm(hidden_size, eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ self.linear_end = ClippedLinear(hidden_size, hidden_size, device=device, dtype=dtype, ops=ops)
+
+ def forward(self, x):
+ residual = x
+ x = self.pre_layer_norm(x)
+ x = self.linear_start(x)
+ x = torch.nn.functional.glu(x, dim=-1)
+ x = x.transpose(1, 2)
+ x = torch.nn.functional.pad(x, (self.conv_left_pad, 0))
+ x = self.depthwise_conv1d(x).transpose(1, 2)
+ x = self.conv_norm(x)
+ x = torch.nn.functional.silu(x)
+ x = self.linear_end(x)
+ return x + residual
+
+
+class Gemma4AudioLayer(nn.Module):
+ """Conformer block: FFN1 -> Attention -> LConv -> FFN2."""
+ def __init__(self, config, device=None, dtype=None, ops=None):
+ super().__init__()
+ self.feed_forward1 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
+ self.self_attn = Gemma4AudioAttention(config, device=device, dtype=dtype, ops=ops)
+ norm_kwargs = dict(eps=config["rms_norm_eps"], device=device, dtype=dtype)
+ hidden_size = config["hidden_size"]
+ self.norm_pre_attn = RMSNorm(hidden_size, **norm_kwargs)
+ self.norm_post_attn = RMSNorm(hidden_size, **norm_kwargs)
+ self.lconv1d = Gemma4AudioLConv1d(config, device=device, dtype=dtype, ops=ops)
+ self.feed_forward2 = Gemma4AudioFeedForward(config, device=device, dtype=dtype, ops=ops)
+ self.norm_out = RMSNorm(hidden_size, **norm_kwargs)
+
+ def forward(self, x, position_embeddings=None, attn_mask=None):
+ x = self.feed_forward1(x)
+
+ residual = x
+ x = self.norm_pre_attn(x)
+ x = self.self_attn(x, position_embeddings=position_embeddings, attn_mask=attn_mask)
+ x = self.norm_post_attn(x)
+ x = x + residual
+
+ x = self.lconv1d(x)
+ x = self.feed_forward2(x)
+
+ x = self.norm_out(x)
+ return x
+
+
+class Gemma4AudioEncoder(nn.Module):
+ def __init__(self, config, dtype=None, device=None, ops=None):
+ super().__init__()
+ self.hidden_size = config["hidden_size"]
+ self.output_proj_dims = config.get("output_proj_dims", 1536)
+
+ self.subsample_conv_projection = Gemma4AudioConvSubsampler(config, device=device, dtype=dtype, ops=ops)
+ self.rel_pos_enc = Gemma4AudioRelPositionalEncoding(config, device=device, dtype=dtype)
+
+ self.layers = nn.ModuleList([
+ Gemma4AudioLayer(config, device=device, dtype=dtype, ops=ops)
+ for _ in range(config["num_hidden_layers"])
+ ])
+
+ self.output_proj = ops.Linear(self.hidden_size, self.output_proj_dims, bias=True, device=device, dtype=dtype)
+
+ def forward(self, audio_features, audio_mask=None):
+ x, audio_mask = self.subsample_conv_projection(audio_features, audio_mask)
+ position_embeddings = self.rel_pos_enc(x)
+
+ # Build blocked attention mask once for all layers
+ attn_mask = self.layers[0].self_attn._build_blocked_mask(
+ x.shape[1], (x.shape[1] + self.layers[0].self_attn.chunk_size - 1) // self.layers[0].self_attn.chunk_size,
+ x.device, audio_mask=audio_mask)
+
+ for layer in self.layers:
+ x = layer(x, position_embeddings=position_embeddings, attn_mask=attn_mask)
+
+ x = self.output_proj(x)
+ return x
+
+
+class Gemma4AudioProjector(Gemma4RMSNormProjector):
+ def __init__(self, config, dtype=None, device=None, ops=None):
+ super().__init__(config.get("audio_output_proj_dims", 1536), config.get("text_hidden_size", 2560), dtype=dtype, device=device, ops=ops)
+
+
+# Tokenizer and Wrappers
+
+class Gemma4_Tokenizer():
+ tokenizer_json_data = None
+
+ def state_dict(self):
+ if self.tokenizer_json_data is not None:
+ return {"tokenizer_json": self.tokenizer_json_data}
+ return {}
+
+ def _extract_mel_spectrogram(self, waveform, sample_rate):
+ """Extract 128-bin log mel spectrogram.
+ Uses numpy for FFT/matmul/log to produce results bit-identical to the reference code.
+ """
+ # Mix to mono first, then resample to 16kHz
+ if waveform.dim() > 1 and waveform.shape[0] > 1:
+ waveform = waveform.mean(dim=0, keepdim=True)
+ if waveform.dim() == 1:
+ waveform = waveform.unsqueeze(0)
+ audio = waveform.squeeze(0).float().numpy()
+ if sample_rate != 16000:
+ # Use scipy's resample_poly with a high-quality FIR filter to get as close as possible to librosa's resampling (still not an exact match)
+ from scipy.signal import resample_poly, firwin
+ from math import gcd
+ g = gcd(sample_rate, 16000)
+ up, down = 16000 // g, sample_rate // g
+ L = max(up, down)
+ h = firwin(160 * L + 1, 0.96 / L, window=('kaiser', 6.5))
+ audio = resample_poly(audio, up, down, window=h).astype(np.float32)
+ n = len(audio)
+
+ # Pad to multiple of 128, build sample-level mask
+ if n % 128 != 0:
+ audio = np.pad(audio, (0, 128 - n % 128))
+ mask_raw = np.ones(len(audio), dtype=np.float32)
+ mask_raw[n:] = 0.0
+
+ # Semicausal padding: 160 zeros prepended
+ audio = np.pad(audio, (160, 0))
+ mask_raw = np.pad(mask_raw, (160, 0))
+
+ # Extract 321-sample frames via stride tricks, then drop the last sample of each frame → 320 samples
+ nf = (len(audio) - 321) // 160 + 1
+ strides = (audio.strides[0] * 160, audio.strides[0])
+ frames = np.lib.stride_tricks.as_strided(audio, (nf, 321), strides)[..., :-1].copy()
+
+ # Periodic Hann window, FFT magnitude, mel filterbank, log
+ window = (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(320) / 320)).astype(np.float32)
+ magnitude = np.abs(np.fft.rfft(frames * window, n=512, axis=-1))
+ mel_fb = self._build_mel_filterbank()
+ log_mel = np.log(np.matmul(magnitude, mel_fb) + np.float64(0.001)).astype(np.float32)
+
+ # Frame mask: valid when last sample in window is real audio
+ mask = mask_raw[np.arange(nf) * 160 + 320].astype(bool)
+ log_mel = log_mel * mask[:, None]
+ return torch.from_numpy(log_mel), torch.from_numpy(mask) # [T, 128], [T]
+
+ @staticmethod
+ def _build_mel_filterbank():
+ """Build 128-bin HTK mel filterbank [257, 128] for 512-pt FFT at 16kHz."""
+ mel_freqs = np.linspace(0.0, 2595.0 * np.log10(1.0 + 8000.0 / 700.0), 130)
+ filter_freqs = 700.0 * (10.0 ** (mel_freqs / 2595.0) - 1.0)
+ fft_freqs = np.linspace(0, 16000 // 2, 257)
+ filter_diff = np.diff(filter_freqs)
+ slopes = np.expand_dims(filter_freqs, 0) - np.expand_dims(fft_freqs, 1)
+ down_slopes = -slopes[:, :-2] / filter_diff[:-1]
+ up_slopes = slopes[:, 2:] / filter_diff[1:]
+ return np.maximum(np.zeros(1), np.minimum(down_slopes, up_slopes))
+
+ def tokenize_with_weights(self, text, return_word_ids=False, image=None, audio=None, video=None, llama_template=None, skip_template=True, thinking=False, **kwargs):
+
+ # Process audio
+ audio_features = []
+ if audio is not None:
+ waveform = audio["waveform"].squeeze(0) if hasattr(audio, "__getitem__") else audio
+ sample_rate = audio.get("sample_rate", 16000) if hasattr(audio, "get") else 16000
+ mel, mel_mask = self._extract_mel_spectrogram(waveform, sample_rate)
+ audio_features = [(mel.unsqueeze(0), mel_mask.unsqueeze(0))] # ([1, T, 128], [1, T])
+
+ # Process image/video frames
+ is_video = video is not None
+ source = video if is_video else image
+ images = []
+ if source is not None:
+ samples = source.movedim(-1, 1) # [B, C, H, W]
+ num_frames = samples.shape[0]
+
+ # Subsample video to 1fps
+ if is_video:
+ fps = kwargs.get("fps", 24)
+ step = max(1, round(fps))
+ indices = list(range(0, num_frames, step))
+ if len(indices) == 0:
+ indices = [0]
+ samples = samples[indices]
+ num_frames = len(indices)
+
+ h, w = samples.shape[2], samples.shape[3]
+ patch_size = 16
+ pooling_k = 3
+ max_soft_tokens = 70 if is_video else 280 # video uses a smaller token budget per frame
+ max_patches = max_soft_tokens * pooling_k * pooling_k
+ target_px = max_patches * patch_size * patch_size
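+ # Resize so the patch count fits the soft-token budget; each side is rounded down to a multiple of pooling_k * patch_size (minimum one multiple).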
+ factor = (target_px / (h * w)) ** 0.5
+ side_mult = pooling_k * patch_size
+ target_h = max(int(factor * h // side_mult) * side_mult, side_mult)
+ target_w = max(int(factor * w // side_mult) * side_mult, side_mult)
+
+ import torchvision.transforms.functional as TVF
+ for i in range(num_frames):
+ # rescale to uint8 [0, 255] before resizing to match the reference code's preprocessing
+ s = (samples[i].clamp(0, 1) * 255).to(torch.uint8) # [C, H, W] uint8
+ if target_h != h or target_w != w:
+ s = TVF.resize(s, [target_h, target_w], interpolation=TVF.InterpolationMode.BICUBIC, antialias=True)
+ s = s.float() * (1.0 / 255.0)
+ images.append({"pixels": s.unsqueeze(0).movedim(1, -1)[:, :, :, :3], "max_soft_tokens": max_soft_tokens})
+
+ if text.startswith('<|turn>'):
+ skip_template = True
+
+ if skip_template:
+ llama_text = text
+ else:
+ if llama_template is not None:
+ llama_text = llama_template.format(text)
+ else:
+ # Build template from modalities present
+ system = "<|turn>system\n<|think|>\n" if thinking else ""
+ media = ""
+ if len(images) > 0:
+ if is_video:
+ media += "\n\n"
+ for i in range(len(images)):
+ ts = f"{int(i // 60):02d}:{int(i % 60):02d}"
+ sep = "" if i == 0 else " "
+ media += f"{sep}{ts} <|image><|video|>"
+ media += "\n\n"
+ else:
+ media += "\n\n"
+ for i in range(len(images)):
+ if i > 0:
+ media += "\n\n\n\n"
+ media += "<|image><|image|>"
+ media += "\n\n"
+ if len(audio_features) > 0:
+ # Compute audio token count (always at 16kHz)
+ num_samples = int(waveform.shape[-1] * 16000 / sample_rate) if sample_rate != 16000 else waveform.shape[-1]
+ _fl = 320 # int(round(16000 * 20.0 / 1000.0))
+ _hl = 160 # int(round(16000 * 10.0 / 1000.0))
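+ # Number of mel frames, then two stride-2 conv subsampling stages (kernel 3, padding 1), capped at 750 audio tokens.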
+ _nmel = (num_samples + _fl // 2 - (_fl + 1)) // _hl + 1
+ _t = _nmel
+ for _ in range(2):
+ _t = (_t + 2 - 3) // 2 + 1
+ n_audio_tokens = min(_t, 750)
+ media += "<|audio>" + "<|audio|>" * n_audio_tokens + "