Merge branch 'master' into trellis2

This commit is contained in:
Yousef R. Gamaleldin 2026-05-05 23:20:05 +03:00 committed by GitHub
commit 81ed835ffb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
153 changed files with 46086 additions and 2638 deletions

View File

@ -1,2 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --disable-smart-memory .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --enable-dynamic-vram
pause pause

View File

@ -0,0 +1,45 @@
name: Tag Dispatch to Cloud
on:
push:
tags:
- 'v*'
jobs:
dispatch-cloud:
runs-on: ubuntu-latest
steps:
- name: Send repository dispatch to cloud
env:
DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.ref_name }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}"
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_tag_pushed",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/cloud/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud"

2
.gitignore vendored
View File

@ -21,6 +21,6 @@ venv*/
*.log *.log
web_custom_versions/ web_custom_versions/
.DS_Store .DS_Store
openapi.yaml
filtered-openapi.yaml filtered-openapi.yaml
uv.lock uv.lock
.comfy_environment

View File

@ -1,2 +1,2 @@
# Admins # Admins
* @comfyanonymous @kosinkadink @guill * @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai

View File

@ -139,9 +139,9 @@ Example:
"_quantization_metadata": { "_quantization_metadata": {
"format_version": "1.0", "format_version": "1.0",
"layers": { "layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn", "model.layers.0.mlp.up_proj": {"format": "float8_e4m3fn"},
"model.layers.0.mlp.down_proj": "float8_e4m3fn", "model.layers.0.mlp.down_proj": {"format": "float8_e4m3fn"},
"model.layers.1.mlp.up_proj": "float8_e4m3fn" "model.layers.1.mlp.up_proj": {"format": "float8_e4m3fn"}
} }
} }
} }
@ -165,4 +165,4 @@ Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_s
3. **Compute scales**: Derive `input_scale` from collected statistics 3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights 4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.

View File

@ -1,7 +1,7 @@
<div align="center"> <div align="center">
# ComfyUI # ComfyUI
**The most powerful and modular visual AI engine and application.** **The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url] [![Website][website-shield]][website-url]
@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) <img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div> </div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started ## Get Started
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models - Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@ -126,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories: ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)** 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
- Releases a new stable version (e.g., v0.7.0) roughly every week. - Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release. - Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
- Minor versions will be used for releases off the master branch. - Minor versions will be used for releases off the master branch.
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense. - Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
@ -193,11 +200,15 @@ If you have trouble extracting it, right click the file -> properties -> unblock
The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start. The portable above currently comes with python 3.13 and pytorch cuda 13.0. Update your Nvidia drivers if it doesn't start.
#### Alternative Downloads: #### All Official Portable Downloads:
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z) [Portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs). [Portable for Intel GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_intel.7z)
[Portable for Nvidia GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z) (supports 20 series and above).
[Portable for Nvidia GPUs with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI? #### How do I share models between another UI and ComfyUI?

View File

@ -67,7 +67,7 @@ class InternalRoutes:
(entry for entry in os.scandir(directory) if is_visible_file(entry)), (entry for entry in os.scandir(directory) if is_visible_file(entry)),
key=lambda entry: -entry.stat().st_mtime key=lambda entry: -entry.stat().st_mtime
) )
return web.json_response([entry.name for entry in sorted_files], status=200) return web.json_response([f"{entry.name} [{directory_type}]" for entry in sorted_files], status=200)
def get_app(self): def get_app(self):

View File

@ -2,7 +2,6 @@
precision mediump float; precision mediump float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint uniform int u_int1; // Color tint
uniform float u_float0; // Intensity uniform float u_float0; // Intensity
@ -75,7 +74,7 @@ void main() {
float t0 = threshold - 0.15; float t0 = threshold - 0.15;
float t1 = threshold + 0.15; float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution; vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius2 = radius * radius; float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0); float sampleScale = clamp(radius * 0.75, 0.35, 1.0);

View File

@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003; const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL) uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical) uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
@ -25,7 +24,7 @@ float gaussian(float x, float sigma) {
} }
void main() { void main() {
vec2 texelSize = 1.0 / u_resolution; vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float0, 0.0); float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable // Radial (angular) blur - single pass, doesn't use separable

View File

@ -2,14 +2,13 @@
precision highp float; precision highp float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0 uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord; in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0; layout(location = 0) out vec4 fragColor0;
void main() { void main() {
vec2 texel = 1.0 / u_resolution; vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
// Sample center and neighbors // Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord); vec4 center = texture(u_image0, v_texCoord);

View File

@ -2,7 +2,6 @@
precision highp float; precision highp float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5 uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
@ -19,7 +18,7 @@ float getLuminance(vec3 color) {
} }
void main() { void main() {
vec2 texel = 1.0 / u_resolution; vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float1, 0.5); float radius = max(u_float1, 0.5);
float amount = u_float0; float amount = u_float0;
float threshold = u_float2; float threshold = u_float2;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -160,7 +160,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Depth to Image (Z-Image-Turbo)", "name": "Depth to Image (Z-Image-Turbo)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -2482,4 +2482,4 @@
"VHS_KeepIntermediate": true "VHS_KeepIntermediate": true
}, },
"version": 0.4 "version": 0.4
} }

View File

@ -261,7 +261,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Depth to Video (LTX 2.0)", "name": "Depth to Video (LTX 2.0)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -5208,4 +5208,4 @@
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"version": 0.4 "version": 0.4
} }

File diff suppressed because it is too large Load Diff

View File

@ -268,7 +268,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
"from_input" "from_input"
] ]
}, },

View File

@ -331,7 +331,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
"from_input" "from_input"
] ]
} }

File diff suppressed because it is too large Load Diff

View File

@ -128,7 +128,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Edit (Flux.2 Klein 4B)", "name": "Image Edit (Flux.2 Klein 4B)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -1837,4 +1837,4 @@
} }
}, },
"version": 0.4 "version": 0.4
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -124,7 +124,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Inpainting (Qwen-image)", "name": "Image Inpainting (Qwen-image)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -1923,4 +1923,4 @@
"workflowRendererVersion": "LG" "workflowRendererVersion": "LG"
}, },
"version": 0.4 "version": 0.4
} }

View File

@ -204,7 +204,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Outpainting (Qwen-Image)", "name": "Image Outpainting (Qwen-Image)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
@ -2749,4 +2749,4 @@
} }
}, },
"version": 0.4 "version": 0.4
} }

View File

@ -1,15 +1,14 @@
{ {
"id": "1a761372-7c82-4016-b9bf-fa285967e1e9",
"revision": 0, "revision": 0,
"last_node_id": 83, "last_node_id": 176,
"last_link_id": 0, "last_link_id": 0,
"nodes": [ "nodes": [
{ {
"id": 83, "id": 176,
"type": "f754a936-daaf-4b6e-9658-41fdc54d301d", "type": "2d2e3c8e-53b3-4618-be52-6d1d99382f0e",
"pos": [ "pos": [
61.999827823554256, -1150,
153.3332507624185 200
], ],
"size": [ "size": [
400, 400,
@ -56,6 +55,38 @@
"name": "layers" "name": "layers"
}, },
"link": null "link": null
},
{
"name": "seed",
"type": "INT",
"widget": {
"name": "seed"
},
"link": null
},
{
"name": "unet_name",
"type": "COMBO",
"widget": {
"name": "unet_name"
},
"link": null
},
{
"name": "clip_name",
"type": "COMBO",
"widget": {
"name": "clip_name"
},
"link": null
},
{
"name": "vae_name",
"type": "COMBO",
"widget": {
"name": "vae_name"
},
"link": null
} }
], ],
"outputs": [ "outputs": [
@ -66,28 +97,41 @@
"links": [] "links": []
} }
], ],
"title": "Image to Layers (Qwen-Image-Layered)",
"properties": { "properties": {
"proxyWidgets": [ "proxyWidgets": [
[ [
"-1", "6",
"text" "text"
], ],
[ [
"-1", "3",
"steps" "steps"
], ],
[ [
"-1", "3",
"cfg" "cfg"
], ],
[ [
"-1", "83",
"layers" "layers"
], ],
[ [
"3", "3",
"seed" "seed"
], ],
[
"37",
"unet_name"
],
[
"38",
"clip_name"
],
[
"39",
"vae_name"
],
[ [
"3", "3",
"control_after_generate" "control_after_generate"
@ -95,6 +139,11 @@
], ],
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -103,25 +152,20 @@
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, },
"widgets_values": [ "widgets_values": []
"",
20,
2.5,
2
]
} }
], ],
"links": [], "links": [],
"groups": [], "version": 0.4,
"definitions": { "definitions": {
"subgraphs": [ "subgraphs": [
{ {
"id": "f754a936-daaf-4b6e-9658-41fdc54d301d", "id": "2d2e3c8e-53b3-4618-be52-6d1d99382f0e",
"version": 1, "version": 1,
"state": { "state": {
"lastGroupId": 3, "lastGroupId": 8,
"lastNodeId": 83, "lastNodeId": 176,
"lastLinkId": 159, "lastLinkId": 380,
"lastRerouteId": 0 "lastRerouteId": 0
}, },
"revision": 0, "revision": 0,
@ -130,10 +174,10 @@
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
-510, -720,
523, 720,
120, 120,
140 220
] ]
}, },
"outputNode": { "outputNode": {
@ -156,8 +200,8 @@
], ],
"localized_name": "image", "localized_name": "image",
"pos": [ "pos": [
-410, -620,
543 740
] ]
}, },
{ {
@ -168,8 +212,8 @@
150 150
], ],
"pos": [ "pos": [
-410, -620,
563 760
] ]
}, },
{ {
@ -180,8 +224,8 @@
153 153
], ],
"pos": [ "pos": [
-410, -620,
583 780
] ]
}, },
{ {
@ -192,8 +236,8 @@
154 154
], ],
"pos": [ "pos": [
-410, -620,
603 800
] ]
}, },
{ {
@ -204,8 +248,56 @@
159 159
], ],
"pos": [ "pos": [
-410, -620,
623 820
]
},
{
"id": "9f76338b-f4ca-4bb3-b61a-57b3f233061e",
"name": "seed",
"type": "INT",
"linkIds": [
377
],
"pos": [
-620,
840
]
},
{
"id": "8d0422d5-5eee-4f7e-9817-dc613cc62eca",
"name": "unet_name",
"type": "COMBO",
"linkIds": [
378
],
"pos": [
-620,
860
]
},
{
"id": "552eece2-a735-4d00-ae78-ded454622bc1",
"name": "clip_name",
"type": "COMBO",
"linkIds": [
379
],
"pos": [
-620,
880
]
},
{
"id": "1e6d141c-d0f9-4a2b-895c-b6780e57cfa0",
"name": "vae_name",
"type": "COMBO",
"linkIds": [
380
],
"pos": [
-620,
900
] ]
} }
], ],
@ -231,14 +323,14 @@
"type": "CLIPLoader", "type": "CLIPLoader",
"pos": [ "pos": [
-320, -320,
310 360
], ],
"size": [ "size": [
346.7470703125, 350,
106 150
], ],
"flags": {}, "flags": {},
"order": 0, "order": 5,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -248,7 +340,7 @@
"widget": { "widget": {
"name": "clip_name" "name": "clip_name"
}, },
"link": null "link": 379
}, },
{ {
"localized_name": "type", "localized_name": "type",
@ -283,9 +375,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "CLIPLoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPLoader",
"models": [ "models": [
{ {
"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "name": "qwen_2.5_vl_7b_fp8_scaled.safetensors",
@ -312,14 +409,14 @@
"type": "VAELoader", "type": "VAELoader",
"pos": [ "pos": [
-320, -320,
460 580
], ],
"size": [ "size": [
346.7470703125, 350,
58 110
], ],
"flags": {}, "flags": {},
"order": 1, "order": 6,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -329,7 +426,7 @@
"widget": { "widget": {
"name": "vae_name" "name": "vae_name"
}, },
"link": null "link": 380
} }
], ],
"outputs": [ "outputs": [
@ -345,9 +442,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAELoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAELoader",
"models": [ "models": [
{ {
"name": "qwen_image_layered_vae.safetensors", "name": "qwen_image_layered_vae.safetensors",
@ -375,11 +477,11 @@
420 420
], ],
"size": [ "size": [
425.27801513671875, 430,
180.6060791015625 190
], ],
"flags": {}, "flags": {},
"order": 3, "order": 2,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -411,9 +513,14 @@
], ],
"title": "CLIP Text Encode (Negative Prompt)", "title": "CLIP Text Encode (Negative Prompt)",
"properties": { "properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -432,12 +539,12 @@
"id": 70, "id": 70,
"type": "ReferenceLatent", "type": "ReferenceLatent",
"pos": [ "pos": [
330, 140,
670 700
], ],
"size": [ "size": [
204.1666717529297, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
@ -470,9 +577,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ReferenceLatent",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ReferenceLatent",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -480,19 +592,18 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 69, "id": 69,
"type": "ReferenceLatent", "type": "ReferenceLatent",
"pos": [ "pos": [
330, 160,
710 820
], ],
"size": [ "size": [
204.1666717529297, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
@ -525,9 +636,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ReferenceLatent",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ReferenceLatent",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -535,8 +651,7 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 66, "id": 66,
@ -547,10 +662,10 @@
], ],
"size": [ "size": [
270, 270,
58 110
], ],
"flags": {}, "flags": {},
"order": 4, "order": 7,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -580,9 +695,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ModelSamplingAuraFlow",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ModelSamplingAuraFlow",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -600,11 +720,11 @@
"type": "LatentCutToBatch", "type": "LatentCutToBatch",
"pos": [ "pos": [
830, 830,
160 140
], ],
"size": [ "size": [
270, 270,
82 140
], ],
"flags": {}, "flags": {},
"order": 11, "order": 11,
@ -646,9 +766,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "LatentCutToBatch",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "LatentCutToBatch",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -666,12 +791,12 @@
"id": 71, "id": 71,
"type": "VAEEncode", "type": "VAEEncode",
"pos": [ "pos": [
100, -280,
690 780
], ],
"size": [ "size": [
140, 230,
46 100
], ],
"flags": { "flags": {
"collapsed": false "collapsed": false
@ -704,9 +829,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAEEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -714,24 +844,23 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 8, "id": 8,
"type": "VAEDecode", "type": "VAEDecode",
"pos": [ "pos": [
850, 850,
310 370
], ],
"size": [ "size": [
210, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
}, },
"order": 7, "order": 3,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -759,9 +888,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEDecode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAEDecode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -769,8 +903,7 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 6, "id": 6,
@ -780,11 +913,11 @@
180 180
], ],
"size": [ "size": [
422.84503173828125, 430,
164.31304931640625 170
], ],
"flags": {}, "flags": {},
"order": 6, "order": 1,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -816,9 +949,14 @@
], ],
"title": "CLIP Text Encode (Positive Prompt)", "title": "CLIP Text Encode (Positive Prompt)",
"properties": { "properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -838,14 +976,14 @@
"type": "KSampler", "type": "KSampler",
"pos": [ "pos": [
530, 530,
280 340
], ],
"size": [ "size": [
270, 270,
400 400
], ],
"flags": {}, "flags": {},
"order": 5, "order": 0,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -879,7 +1017,7 @@
"widget": { "widget": {
"name": "seed" "name": "seed"
}, },
"link": null "link": 377
}, },
{ {
"localized_name": "steps", "localized_name": "steps",
@ -939,9 +1077,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "KSampler",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "KSampler",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -964,12 +1107,12 @@
"id": 78, "id": 78,
"type": "GetImageSize", "type": "GetImageSize",
"pos": [ "pos": [
80, -280,
790 930
], ],
"size": [ "size": [
210, 230,
136 140
], ],
"flags": {}, "flags": {},
"order": 12, "order": 12,
@ -1007,9 +1150,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "GetImageSize",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "GetImageSize",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -1017,23 +1165,23 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 83, "id": 83,
"type": "EmptyQwenImageLayeredLatentImage", "type": "EmptyQwenImageLayeredLatentImage",
"pos": [ "pos": [
320, -280,
790 1120
], ],
"size": [ "size": [
330.9341796875, 340,
130 200
], ],
"flags": {}, "flags": {},
"order": 13, "order": 13,
"mode": 0, "mode": 0,
"showAdvanced": true,
"inputs": [ "inputs": [
{ {
"localized_name": "width", "localized_name": "width",
@ -1083,9 +1231,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "EmptyQwenImageLayeredLatentImage",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "EmptyQwenImageLayeredLatentImage",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -1109,11 +1262,11 @@
180 180
], ],
"size": [ "size": [
346.7470703125, 350,
82 110
], ],
"flags": {}, "flags": {},
"order": 2, "order": 4,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -1123,7 +1276,7 @@
"widget": { "widget": {
"name": "unet_name" "name": "unet_name"
}, },
"link": null "link": 378
}, },
{ {
"localized_name": "weight_dtype", "localized_name": "weight_dtype",
@ -1147,9 +1300,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "UNETLoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "UNETLoader",
"models": [ "models": [
{ {
"name": "qwen_image_layered_bf16.safetensors", "name": "qwen_image_layered_bf16.safetensors",
@ -1191,8 +1349,8 @@
"bounding": [ "bounding": [
-330, -330,
110, 110,
366.7470703125, 370,
421.6 610
], ],
"color": "#3f789e", "color": "#3f789e",
"font_size": 24, "font_size": 24,
@ -1391,6 +1549,38 @@
"target_id": 83, "target_id": 83,
"target_slot": 2, "target_slot": 2,
"type": "INT" "type": "INT"
},
{
"id": 377,
"origin_id": -10,
"origin_slot": 5,
"target_id": 3,
"target_slot": 4,
"type": "INT"
},
{
"id": 378,
"origin_id": -10,
"origin_slot": 6,
"target_id": 37,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 379,
"origin_id": -10,
"origin_slot": 7,
"target_id": 38,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 380,
"origin_id": -10,
"origin_slot": 8,
"target_id": 39,
"target_slot": 0,
"type": "COMBO"
} }
], ],
"extra": { "extra": {
@ -1400,7 +1590,6 @@
} }
] ]
}, },
"config": {},
"extra": { "extra": {
"ds": { "ds": {
"scale": 1.14, "scale": 1.14,
@ -1409,7 +1598,6 @@
6.855893974423647 6.855893974423647
] ]
}, },
"workflowRendererVersion": "LG" "ue_links": []
}, }
"version": 0.4 }
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -267,7 +267,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
"from_input" "from_input"
] ]
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -383,7 +383,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
"from_input" "from_input"
] ]
} }

View File

@ -90,8 +90,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum): class LatentPreviewMethod(enum.Enum):
NoPreviews = "none" NoPreviews = "none"
@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()

View File

@ -0,0 +1,34 @@
import functools
import logging
import os
logger = logging.getLogger(__name__)
_DEFAULT_DEPLOY_ENV = "local-git"
_ENV_FILENAME = ".comfy_environment"
# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
# We deliberately avoid `folder_paths.base_path` here because that is overridden
# by the `--base-directory` CLI arg to a user-supplied path, whereas the
# `.comfy_environment` marker is written by launchers/installers next to the
# ComfyUI install itself.
_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
@functools.cache
def get_deploy_environment() -> str:
env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
try:
with open(env_file, encoding="utf-8") as f:
# Cap the read so a malformed or maliciously crafted file (e.g.
# a single huge line with no newline) can't blow up memory.
first_line = f.readline(128).strip()
value = "".join(c for c in first_line if 32 <= ord(c) < 127)
if value:
return value
except FileNotFoundError:
pass
except Exception as e:
logger.error("Failed to read %s: %s", env_file, e)
return _DEFAULT_DEPLOY_ENV

View File

@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023).""" """Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
num_frame_per_block=1):
"""
Autoregressive video sampler: block-by-block denoising with KV cache
and flow-match re-noising for Causal Forcing / Self-Forcing models.
Requires a Causal-WAN compatible model (diffusion_model must expose
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
All AR-loop parameters are passed via the SamplerARVideo node, not read
from the checkpoint or transformer_options.
"""
extra_args = {} if extra_args is None else extra_args
model_options = extra_args.get("model_options", {})
transformer_options = model_options.get("transformer_options", {})
if x.ndim != 5:
raise ValueError(
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
)
inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
raise TypeError(
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
"does not support this interface — choose a different sampler."
)
seed = extra_args.get("seed", 0)
bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
device = x.device
model_dtype = inner_model.get_dtype()
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0
try:
for block_idx in trange(num_blocks, disable=disable):
bf = min(num_frame_per_block, lat_t - current_start_frame)
fs, fe = current_start_frame, current_start_frame + bf
noisy_input = x[:, :, fs:fe]
ar_state = {
"start_frame": current_start_frame,
"kv_caches": kv_caches,
"crossattn_caches": crossattn_caches,
}
transformer_options["ar_state"] = ar_state
for i in range(num_sigma_steps):
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
if callback is not None:
scaled_i = step_count * num_sigma_steps // total_real_steps
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
"sigma_hat": sigmas[i], "denoised": denoised})
if sigmas[i + 1] == 0:
noisy_input = denoised
else:
sigma_next = sigmas[i + 1]
torch.manual_seed(seed + block_idx * 1000 + i)
fresh_noise = torch.randn_like(denoised)
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
step_count += 1
output[:, :, fs:fe] = noisy_input
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
current_start_frame += bf
finally:
transformer_options.pop("ar_state", None)
return output

View File

@ -224,6 +224,7 @@ class Flux2(LatentFormat):
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
self.taesd_decoder_name = "taef2_decoder"
def process_in(self, latent): def process_in(self, latent):
return latent return latent
@ -785,3 +786,10 @@ class ZImagePixelSpace(ChromaRadiance):
No VAE encoding/decoding the model operates directly on RGB pixels. No VAE encoding/decoding the model operates directly on RGB pixels.
""" """
pass pass
class CogVideoX(LatentFormat):
latent_channels = 16
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.15258426

View File

573
comfy/ldm/cogvideo/model.py Normal file
View File

@ -0,0 +1,573 @@
# CogVideoX 3D Transformer - ported to ComfyUI native ops
# Architecture reference: diffusers CogVideoXTransformer3DModel
# Style reference: comfy/ldm/wan/model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.patcher_extension
import comfy.ldm.common_dit
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
"""Returns (cos, sin) each with shape [seq_len, dim].
Frequencies are computed at dim//2 resolution then repeat_interleaved
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
angles = torch.outer(pos.float(), freqs.float())
cos = angles.cos().repeat_interleave(2, dim=-1).float()
sin = angles.sin().repeat_interleave(2, dim=-1).float()
return (cos, sin)
def apply_rotary_emb(x, freqs_cos_sin):
"""Apply CogVideoX rotary embedding to query or key tensor.
x: [B, heads, seq_len, head_dim]
freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2]
Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
"""
cos, sin = freqs_cos_sin
cos = cos[None, None, :, :].to(x.device)
sin = sin[None, None, :, :].to(x.device)
# Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
args = timesteps[:, None].float() * freqs[None] * scale
embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
if flip_sin_to_cos:
embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
embed_dim_spatial = 2 * (embed_dim // 3)
embed_dim_temporal = embed_dim // 3
pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
T, H, W = grid_t.shape
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
return pos_embed
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
T, H, W = grid_h.shape
half_dim = embed_dim // 2
pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
return torch.cat([pos_h, pos_w], dim=-1)
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
half = embed_dim // 2
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
args = pos.float().reshape(-1)[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if embed_dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class CogVideoXPatchEmbed(nn.Module):
def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
text_dim=4096, bias=True, sample_width=90, sample_height=60,
sample_frames=49, temporal_compression_ratio=4,
max_text_seq_length=226, spatial_interpolation_scale=1.875,
temporal_interpolation_scale=1.0, use_positional_embeddings=True,
use_learned_positional_embeddings=True,
device=None, dtype=None, operations=None):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.dim = dim
self.sample_height = sample_height
self.sample_width = sample_width
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
if patch_size_t is None:
self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
else:
self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
if use_positional_embeddings or use_learned_positional_embeddings:
persistent = use_learned_positional_embeddings
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
if self.patch_size_t is not None:
post_time_compression_frames = post_time_compression_frames // self.patch_size_t
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=device,
)
pos_embedding = pos_embedding.reshape(-1, self.dim)
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
return joint_pos_embedding
def forward(self, text_embeds, image_embeds):
input_dtype = text_embeds.dtype
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
batch_size, num_frames, channels, height, width = image_embeds.shape
proj_dtype = self.proj.weight.dtype
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3)
image_embeds = image_embeds.flatten(1, 2)
else:
p = self.patch_size
p_t = self.patch_size_t
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
text_seq_length = text_embeds.shape[1]
num_image_patches = image_embeds.shape[1]
if self.use_learned_positional_embeddings:
image_pos = self.pos_embedding[
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
].to(device=embeds.device, dtype=embeds.dtype)
else:
image_pos = get_3d_sincos_pos_embed(
self.dim,
(width // self.patch_size, height // self.patch_size),
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=embeds.device,
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
# Build joint: zeros for text + sincos for image
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
joint_pos[:, text_seq_length:] = image_pos
embeds = embeds + joint_pos
return embeds
class CogVideoXLayerNormZero(nn.Module):
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
def forward(self, hidden_states, encoder_hidden_states, temb):
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
class CogVideoXAdaLayerNorm(nn.Module):
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
device=None, dtype=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
def forward(self, x, temb):
temb = self.linear(self.silu(temb))
shift, scale = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class CogVideoXBlock(nn.Module):
def __init__(self, dim, num_heads, head_dim, time_dim,
eps=1e-5, ff_inner_dim=None, ff_bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = head_dim
self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
# Self-attention (joint text + latent)
self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
# Feed-forward (GELU approximate)
inner_dim = ff_inner_dim or dim * 4
self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None):
if transformer_options is None:
transformer_options = {}
text_seq_length = encoder_hidden_states.size(1)
# Norm & modulate
norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
# Joint self-attention
qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
b, s, _ = qkv_input.shape
n, d = self.num_heads, self.head_dim
q = self.q(qkv_input).view(b, s, n, d)
k = self.k(qkv_input).view(b, s, n, d)
v = self.v(qkv_input)
q = self.norm_q(q).view(b, s, n, d)
k = self.norm_k(k).view(b, s, n, d)
# Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
if image_rotary_emb is not None:
q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
k_img = k[:, text_seq_length:].transpose(1, 2)
q_img = apply_rotary_emb(q_img, image_rotary_emb)
k_img = apply_rotary_emb(k_img, image_rotary_emb)
q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
attn_out = optimized_attention(
q.reshape(b, s, n * d),
k.reshape(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
attn_out = self.attn_out(attn_out)
attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
hidden_states = hidden_states + gate_msa * attn_hidden
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
# Norm & modulate for FF
norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
# Feed-forward (GELU on concatenated text + latent)
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(nn.Module):
def __init__(self,
num_attention_heads=30,
attention_head_dim=64,
in_channels=16,
out_channels=16,
flip_sin_to_cos=True,
freq_shift=0,
time_embed_dim=512,
ofs_embed_dim=None,
text_embed_dim=4096,
num_layers=30,
dropout=0.0,
attention_bias=True,
sample_width=90,
sample_height=60,
sample_frames=49,
patch_size=2,
patch_size_t=None,
temporal_compression_ratio=4,
max_text_seq_length=226,
spatial_interpolation_scale=1.875,
temporal_interpolation_scale=1.0,
use_rotary_positional_embeddings=False,
use_learned_positional_embeddings=False,
patch_bias=True,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.dtype = dtype
dim = num_attention_heads * attention_head_dim
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.max_text_seq_length = max_text_seq_length
self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
patch_size_t=patch_size_t,
in_channels=in_channels,
dim=dim,
text_dim=text_embed_dim,
bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
device=device, dtype=torch.float32, operations=operations,
)
# 2. Time embedding
self.time_proj_dim = dim
self.time_proj_flip = flip_sin_to_cos
self.time_proj_shift = freq_shift
self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
self.time_embedding_act = nn.SiLU()
self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
# Optional OFS embedding (CogVideoX 1.5 I2V)
self.ofs_proj_dim = ofs_embed_dim
if ofs_embed_dim:
self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
self.ofs_embedding_act = nn.SiLU()
self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
else:
self.ofs_embedding_linear_1 = None
# 3. Transformer blocks
self.blocks = nn.ModuleList([
CogVideoXBlock(
dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
time_dim=time_embed_dim,
eps=1e-5,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_layers)
])
self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
# 4. Output
self.norm_out = CogVideoXAdaLayerNorm(
time_dim=time_embed_dim, dim=dim, eps=1e-5,
device=device, dtype=dtype, operations=operations,
)
if patch_size_t is None:
output_dim = patch_size * patch_size * out_channels
else:
output_dim = patch_size * patch_size * patch_size_t * out_channels
self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.temporal_compression_ratio = temporal_compression_ratio
def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, ofs, transformer_options, **kwargs)
def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
# ComfyUI passes [B, C, T, H, W]
batch_size, channels, t, h, w = x.shape
# Pad to patch size (temporal + spatial), same pattern as WAN
p_t = self.patch_size_t if self.patch_size_t is not None else 1
x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
# CogVideoX expects [B, T, C, H, W]
x = x.permute(0, 2, 1, 3, 4)
batch_size, num_frames, channels, height, width = x.shape
# Time embedding
t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
t_emb = t_emb.to(dtype=x.dtype)
emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
if self.ofs_embedding_linear_1 is not None and ofs is not None:
ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
ofs_emb = ofs_emb.to(dtype=x.dtype)
ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
emb = emb + ofs_emb
# Patch embedding
hidden_states = self.patch_embed(context, x)
text_seq_length = context.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# Rotary embeddings (if used)
image_rotary_emb = None
if self.use_rotary_positional_embeddings:
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
if self.patch_size_t is None:
post_time = num_frames
else:
post_time = num_frames // self.patch_size_t
image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
# Transformer blocks
for i, block in enumerate(self.blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = self.norm_final(hidden_states)
# Output projection
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# Unpatchify
p = self.patch_size
p_t = self.patch_size_t
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
# Back to ComfyUI format [B, C, T, H, W] and crop padding
output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
return output
def _get_rotary_emb(self, h, w, t, device):
"""Compute CogVideoX 3D rotary positional embeddings.
For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode grid positions
are integer arange computed at max_size, then sliced to actual size.
For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
scaled by spatial_interpolation_scale.
"""
d = self.attention_head_dim
dim_t = d // 4
dim_h = d // 8 * 3
dim_w = d // 8 * 3
if self.patch_size_t is not None:
# CogVideoX 1.5: "slice" mode — positions are simple integer indices
# Compute at max(sample_size, actual_size) then slice to actual
base_h = self.patch_embed.sample_height // self.patch_size
base_w = self.patch_embed.sample_width // self.patch_size
max_h = max(base_h, h)
max_w = max(base_w, w)
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
grid_t = torch.arange(t, device=device, dtype=torch.float32)
else:
# CogVideoX 1.0: "linspace" mode with interpolation scale
grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
grid_t = torch.arange(t, device=device, dtype=torch.float32)
freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
t_cos, t_sin = freqs_t
h_cos, h_sin = freqs_h
w_cos, w_sin = freqs_w
# Slice to actual size (for "slice" mode where grids may be larger)
t_cos, t_sin = t_cos[:t], t_sin[:t]
h_cos, h_sin = h_cos[:h], h_sin[:h]
w_cos, w_sin = w_cos[:w], w_sin[:w]
# Broadcast and concatenate into [T*H*W, head_dim]
t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
return (cos, sin)

566
comfy/ldm/cogvideo/vae.py Normal file
View File

@ -0,0 +1,566 @@
# CogVideoX VAE - ported to ComfyUI native ops
# Architecture reference: diffusers AutoencoderKLCogVideoX
# Style reference: comfy/ldm/wan/vae.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
"""Causal 3D convolution with temporal padding.
Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has
a single temporal frame and no cache, the 3D conv weight is sliced to act
as a 2D conv, avoiding computation on zero-padded temporal dimensions.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
time_kernel, height_kernel, width_kernel = kernel_size
self.time_kernel_size = time_kernel
self.pad_mode = pad_mode
height_pad = (height_kernel - 1) // 2
width_pad = (width_kernel - 1) // 2
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = ops.Conv3d(
in_channels, out_channels, kernel_size,
stride=stride, dilation=dilation,
padding=(0, height_pad, width_pad),
)
def forward(self, x, conv_cache=None):
if self.pad_mode == "replicate":
x = F.pad(x, self.time_causal_padding, mode="replicate")
conv_cache = None
else:
kernel_t = self.time_kernel_size
if kernel_t > 1:
if conv_cache is None and x.shape[2] == 1:
# Fast path: single frame, no cache. All temporal padding
# frames are copies of the input (replicate-style), so the
# 3D conv reduces to a 2D conv with summed temporal kernel.
w = comfy.ops.cast_to_input(self.conv.weight, x)
b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
w2d = w.sum(dim=2, keepdim=True)
out = F.conv3d(x, w2d, b,
self.conv.stride, self.conv.padding,
self.conv.dilation, self.conv.groups)
return out, None
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
x = torch.cat(cached + [x], dim=2)
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
out = self.conv(x)
return out, conv_cache
def _interpolate_zq(zq, target_size):
"""Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
t = target_size[0]
if t > 1 and t % 2 == 1:
z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
return torch.cat([z_first, z_rest], dim=2)
return F.interpolate(zq, size=target_size)
class SpatialNorm3D(nn.Module):
"""Spatially conditioned normalization."""
def __init__(self, f_channels, zq_channels, groups=32):
super().__init__()
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f, zq, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
if zq.shape[-3:] != f.shape[-3:]:
zq = _interpolate_zq(zq, f.shape[-3:])
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
return self.norm_layer(f) * conv_y + conv_b, new_cache
class ResnetBlock3D(nn.Module):
"""3D ResNet block with optional spatial norm."""
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.spatial_norm_dim = spatial_norm_dim
if act_fn == "silu":
self.nonlinearity = nn.SiLU()
elif act_fn == "swish":
self.nonlinearity = nn.SiLU()
else:
self.nonlinearity = nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if in_channels != out_channels:
self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
residual = x
if zq is not None:
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
else:
x = self.norm1(x)
x = self.nonlinearity(x)
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
if temb is not None and hasattr(self, "temb_proj"):
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is not None:
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
else:
x = self.norm2(x)
x = self.nonlinearity(x)
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return x + residual, new_cache
class Downsample3D(nn.Module):
"""3D downsampling with optional temporal compression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
super().__init__()
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
b, c, t, h, w = x.shape
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if t % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
x = F.avg_pool1d(x, kernel_size=2, stride=2)
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.conv(x)
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(nn.Module):
"""3D upsampling with optional temporal decompression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
super().__init__()
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
elif x.shape[2] > 1:
x = F.interpolate(x, scale_factor=2.0)
else:
x = x.squeeze(2)
x = F.interpolate(x, scale_factor=2.0)
x = x[:, :, None, :, :]
else:
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = F.interpolate(x, scale_factor=2.0)
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.conv(x)
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
return x
class DownBlock3D(nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
compress_time=False, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
)
for i in range(num_layers)
])
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
if self.downsamplers is not None:
for ds in self.downsamplers:
x = ds(x)
return x, new_cache
class MidBlock3D(nn.Module):
def __init__(self, in_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels, out_channels=in_channels,
temb_channels=temb_channels, groups=groups, eps=eps,
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
)
for _ in range(num_layers)
])
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
return x, new_cache
class UpBlock3D(nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
add_upsample=True, compress_time=False, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
temb_channels=temb_channels, groups=groups, eps=eps,
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
)
for i in range(num_layers)
])
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
if self.upsamplers is not None:
for us in self.upsamplers:
x = us(x)
return x, new_cache
class Encoder3D(nn.Module):
def __init__(self, in_channels=3, out_channels=16,
block_out_channels=(128, 256, 256, 512),
layers_per_block=3, act_fn="silu",
eps=1e-6, groups=32, pad_mode="first",
temporal_compression_ratio=4):
super().__init__()
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
self.down_blocks.append(DownBlock3D(
in_channels=input_channel, out_channels=output_channel,
temb_channels=0, num_layers=layers_per_block,
eps=eps, act_fn=act_fn, groups=groups,
add_downsample=not is_final, compress_time=compress_time,
))
self.mid_block = MidBlock3D(
in_channels=block_out_channels[-1], temb_channels=0,
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
)
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
def forward(self, x, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
for i, block in enumerate(self.down_blocks):
key = f"down_block_{i}"
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
x = self.norm_out(x)
x = self.conv_act(x)
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
return x, new_cache
class Decoder3D(nn.Module):
def __init__(self, in_channels=16, out_channels=3,
block_out_channels=(128, 256, 256, 512),
layers_per_block=3, act_fn="silu",
eps=1e-6, groups=32, pad_mode="first",
temporal_compression_ratio=4):
super().__init__()
reversed_channels = list(reversed(block_out_channels))
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
self.mid_block = MidBlock3D(
in_channels=reversed_channels[0], temb_channels=0,
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
spatial_norm_dim=in_channels, pad_mode=pad_mode,
)
self.up_blocks = nn.ModuleList()
output_channel = reversed_channels[0]
for i in range(len(block_out_channels)):
prev_channel = output_channel
output_channel = reversed_channels[i]
is_final = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
self.up_blocks.append(UpBlock3D(
in_channels=prev_channel, out_channels=output_channel,
temb_channels=0, num_layers=layers_per_block + 1,
eps=eps, act_fn=act_fn, groups=groups,
spatial_norm_dim=in_channels,
add_upsample=not is_final, compress_time=compress_time,
))
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
def forward(self, sample, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
for i, block in enumerate(self.up_blocks):
key = f"up_block_{i}"
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
x = self.conv_act(x)
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
return x, new_cache
class AutoencoderKLCogVideoX(nn.Module):
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
on the full (low-res) tensor, then the expensive spatial-only up_blocks +
norm_out + conv_out are processed in small temporal chunks with conv_cache
carrying causal state between chunks. This keeps peak VRAM proportional to
chunk_size rather than total frame count.
"""
def __init__(self,
in_channels=3, out_channels=3,
block_out_channels=(128, 256, 256, 512),
latent_channels=16, layers_per_block=3,
act_fn="silu", eps=1e-6, groups=32,
temporal_compression_ratio=4,
):
super().__init__()
self.latent_channels = latent_channels
self.temporal_compression_ratio = temporal_compression_ratio
self.encoder = Encoder3D(
in_channels=in_channels, out_channels=latent_channels,
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
act_fn=act_fn, eps=eps, groups=groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.decoder = Decoder3D(
in_channels=latent_channels, out_channels=out_channels,
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
act_fn=act_fn, eps=eps, groups=groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.num_latent_frames_batch_size = 2
self.num_sample_frames_batch_size = 8
def encode(self, x):
t = x.shape[2]
frame_batch = self.num_sample_frames_batch_size
remainder = t % frame_batch
conv_cache = None
enc = []
# Process remainder frames first so only the first chunk can have an
# odd temporal dimension — where Downsample3D's first-frame-special
# handling in temporal compression is actually correct.
if remainder > 0:
chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache)
enc.append(chunk.to(x.device))
for start in range(remainder, t, frame_batch):
chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache)
enc.append(chunk.to(x.device))
enc = torch.cat(enc, dim=2)
mean, _ = enc.chunk(2, dim=1)
return mean
def decode(self, z):
return self._decode_rolling(z)
def _decode_batched(self, z):
"""Original batched decode - processes 2 latent frames through full decoder."""
t = z.shape[2]
frame_batch = self.num_latent_frames_batch_size
num_batches = max(t // frame_batch, 1)
conv_cache = None
dec = []
for i in range(num_batches):
remaining = t % frame_batch
start = frame_batch * i + (0 if i == 0 else remaining)
end = frame_batch * (i + 1) + remaining
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
dec.append(chunk.cpu())
return torch.cat(dec, dim=2).to(z.device)
def _decode_rolling(self, z):
"""Rolling decode - processes low-res layers on full tensor, then rolls
through expensive high-res layers in temporal chunks."""
decoder = self.decoder
device = z.device
# Determine which up_blocks have temporal upsample vs spatial-only.
# Temporal up_blocks are cheap (low res), spatial-only are expensive.
temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
split_at = temporal_compress_level # first N up_blocks do temporal upsample
# Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
x, _ = decoder.conv_in(z)
x, _ = decoder.mid_block(x, None, z)
for i in range(split_at):
x, _ = decoder.up_blocks[i](x, None, z)
# Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
chunk_size = 4 # pixel frames per chunk through high-res layers
t_expanded = x.shape[2]
if t_expanded <= chunk_size or len(remaining_blocks) == 0:
# Small enough to process in one go
for i in remaining_blocks:
x, _ = decoder.up_blocks[i](x, None, z)
x, _ = decoder.norm_out(x, z)
x = decoder.conv_act(x)
x, _ = decoder.conv_out(x)
return x
# Expand z temporally once to match Phase 2's time dimension.
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
# for the old approach of pre-interpolating to every pixel resolution).
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
# Process in temporal chunks, interpolating spatially per-chunk to avoid
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
dec_out = []
conv_caches = {}
for chunk_start in range(0, t_expanded, chunk_size):
chunk_end = min(chunk_start + chunk_size, t_expanded)
x_chunk = x[:, :, chunk_start:chunk_end]
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
z_spatial_cache = {}
for i in remaining_blocks:
block = decoder.up_blocks[i]
cache_key = f"up_block_{i}"
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
z_spatial_cache[hw_key] = z_t_chunk
else:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
conv_caches[cache_key] = new_cache
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
conv_caches["norm_out"] = new_cache
x_chunk = decoder.conv_act(x_chunk)
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
conv_caches["conv_out"] = new_cache
dec_out.append(x_chunk.cpu())
del z_spatial_cache
del x, z_time_expanded
return torch.cat(dec_out, dim=2).to(device)

View File

@ -15,7 +15,7 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
omega = 1.0 / (theta**scale) omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega) out = torch.einsum("...n,d->...nd", pos.to(device), omega)
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device) return out.to(dtype=torch.float32, device=pos.device)
@ -118,8 +118,6 @@ class ErnieImageAttention(nn.Module):
query = apply_rotary_emb(query, image_rotary_emb) query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb)
query, key = query.to(x.dtype), key.to(x.dtype)
q_flat = query.reshape(B, S, -1) q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1) k_flat = key.reshape(B, S, -1)
@ -161,16 +159,16 @@ class ErnieImageSharedAdaLNBlock(nn.Module):
residual = x residual = x
x_norm = self.adaLN_sa_ln(x) x_norm = self.adaLN_sa_ln(x)
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype) x_norm = x_norm * (1 + scale_msa) + shift_msa
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb) attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype) x = residual + gate_msa * attn_out
residual = x residual = x
x_norm = self.adaLN_mlp_ln(x) x_norm = self.adaLN_mlp_ln(x)
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype) x_norm = x_norm * (1 + scale_mlp) + shift_mlp
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype) return residual + gate_mlp * self.mlp(x_norm)
class ErnieImageAdaLNContinuous(nn.Module): class ErnieImageAdaLNContinuous(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None): def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
@ -183,7 +181,7 @@ class ErnieImageAdaLNContinuous(nn.Module):
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
scale, shift = self.linear(conditioning).chunk(2, dim=-1) scale, shift = self.linear(conditioning).chunk(2, dim=-1)
x = self.norm(x) x = self.norm(x)
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
return x return x
class ErnieImageModel(nn.Module): class ErnieImageModel(nn.Module):
@ -279,7 +277,7 @@ class ErnieImageModel(nn.Module):
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
del image_ids, text_ids del image_ids, text_ids
sample = self.time_proj(timesteps.to(dtype)).to(self.time_embedding.linear_1.weight.dtype) sample = self.time_proj(timesteps).to(dtype)
c = self.time_embedding(sample) c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [

View File

@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.model_prefetch
class CompressedTimestep: class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing.""" """Store video timestep embeddings in compressed form using per-frame indexing."""
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
"""Process transformer blocks for LTXAV.""" """Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
# Process transformer blocks # Process transformer blocks
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
a_prompt_timestep=a_prompt_timestep, a_prompt_timestep=a_prompt_timestep,
) )
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
return [vx, ax] return [vx, ax]
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):

View File

@ -4,9 +4,6 @@ import math
import torch import torch
import torchaudio import torchaudio
import comfy.model_management
import comfy.model_patcher
import comfy.utils as utils
from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
@ -43,30 +40,6 @@ class AudioVAEComponentConfig:
return cls(autoencoder=audio_config, vocoder=vocoder_config) return cls(autoencoder=audio_config, vocoder=vocoder_config)
class ModelDeviceManager:
"""Manages device placement and GPU residency for the composed model."""
def __init__(self, module: torch.nn.Module):
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.vae_offload_device()
self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device)
def ensure_model_loaded(self) -> None:
comfy.model_management.free_memory(
self.patcher.model_size(),
self.patcher.load_device,
)
comfy.model_management.load_model_gpu(self.patcher)
def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(self.patcher.load_device)
@property
def load_device(self):
return self.patcher.load_device
class AudioLatentNormalizer: class AudioLatentNormalizer:
"""Applies per-channel statistics in patch space and restores original layout.""" """Applies per-channel statistics in patch space and restores original layout."""
@ -132,23 +105,17 @@ class AudioPreprocessor:
class AudioVAE(torch.nn.Module): class AudioVAE(torch.nn.Module):
"""High-level Audio VAE wrapper exposing encode and decode entry points.""" """High-level Audio VAE wrapper exposing encode and decode entry points."""
def __init__(self, state_dict: dict, metadata: dict): def __init__(self, metadata: dict):
super().__init__() super().__init__()
component_config = AudioVAEComponentConfig.from_metadata(metadata) component_config = AudioVAEComponentConfig.from_metadata(metadata)
vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True)
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
if "bwe" in component_config.vocoder: if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder) self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else: else:
self.vocoder = Vocoder(config=component_config.vocoder) self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)
autoencoder_config = self.autoencoder.get_config() autoencoder_config = self.autoencoder.get_config()
self.normalizer = AudioLatentNormalizer( self.normalizer = AudioLatentNormalizer(
AudioPatchifier( AudioPatchifier(
@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module):
n_fft=autoencoder_config["n_fft"], n_fft=autoencoder_config["n_fft"],
) )
self.device_manager = ModelDeviceManager(self) def encode(self, audio, sample_rate=44100) -> torch.Tensor:
def encode(self, audio: dict) -> torch.Tensor:
"""Encode a waveform dictionary into normalized latent tensors.""" """Encode a waveform dictionary into normalized latent tensors."""
waveform = audio["waveform"] waveform = audio
waveform_sample_rate = audio["sample_rate"] waveform_sample_rate = sample_rate
input_device = waveform.device input_device = waveform.device
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
waveform = self.device_manager.move_to_load_device(waveform)
expected_channels = self.autoencoder.encoder.in_channels expected_channels = self.autoencoder.encoder.in_channels
if waveform.shape[1] != expected_channels: if waveform.shape[1] != expected_channels:
if waveform.shape[1] == 1: if waveform.shape[1] == 1:
@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module):
) )
mel_spec = self.preprocessor.waveform_to_mel( mel_spec = self.preprocessor.waveform_to_mel(
waveform, waveform_sample_rate, device=self.device_manager.load_device waveform, waveform_sample_rate, device=waveform.device
) )
latents = self.autoencoder.encode(mel_spec) latents = self.autoencoder.encode(mel_spec)
@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module):
"""Decode normalized latent tensors into an audio waveform.""" """Decode normalized latent tensors into an audio waveform."""
original_shape = latents.shape original_shape = latents.shape
# Ensure that Audio VAE is loaded on the correct device.
self.device_manager.ensure_model_loaded()
latents = self.device_manager.move_to_load_device(latents)
latents = self.normalizer.denormalize(latents) latents = self.normalizer.denormalize(latents)
target_shape = self.target_shape_from_latents(original_shape) target_shape = self.target_shape_from_latents(original_shape)
mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape)
waveform = self.run_vocoder(mel_spec) waveform = self.run_vocoder(mel_spec)
return self.device_manager.move_to_load_device(waveform) return waveform
def target_shape_from_latents(self, latents_shape): def target_shape_from_latents(self, latents_shape):
batch, _, time, _ = latents_shape batch, _, time, _ = latents_shape

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads h = heads
if skip_reshape: if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape: if skip_reshape:
query = query.reshape(b * heads, -1, dim_head) query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b: if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape: if not skip_output_reshape:
out = ( out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT],
attn_mask=m, attn_mask=m,
dropout_p=0.0, is_causal=False dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out

View File

@ -34,6 +34,16 @@ class TimestepBlock(nn.Module):
#This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index" #This is needed because accelerate makes a copy of transformer_options which breaks "transformer_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None):
for layer in ts: for layer in ts:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
if isinstance(layer, VideoResBlock): if isinstance(layer, VideoResBlock):
x = layer(x, emb, num_video_frames, image_only_indicator) x = layer(x, emb, num_video_frames, image_only_indicator)
elif isinstance(layer, TimestepBlock): elif isinstance(layer, TimestepBlock):
@ -49,15 +59,6 @@ def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, out
elif isinstance(layer, Upsample): elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape) x = layer(x, output_shape=output_shape)
else: else:
if "patches" in transformer_options and "forward_timestep_embed_patch" in transformer_options["patches"]:
found_patched = False
for class_type, handler in transformer_options["patches"]["forward_timestep_embed_patch"]:
if isinstance(layer, class_type):
x = handler(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator)
found_patched = True
break
if found_patched:
continue
x = layer(x) x = layer(x)
return x return x
@ -894,6 +895,12 @@ class UNetModel(nn.Module):
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator) h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
h = apply_control(h, control, 'middle') h = apply_control(h, control, 'middle')
if "middle_block_after_patch" in transformer_patches:
patch = transformer_patches["middle_block_after_patch"]
for p in patch:
out = p({"h": h, "x": x, "emb": emb, "context": context, "y": y,
"timesteps": timesteps, "transformer_options": transformer_options})
h = out["h"]
for id, module in enumerate(self.output_blocks): for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id) transformer_options["block"] = ("output", id)
@ -905,8 +912,9 @@ class UNetModel(nn.Module):
for p in patch: for p in patch:
h, hsp = p(h, hsp, transformer_options) h, hsp = p(h, hsp, transformer_options)
h = th.cat([h, hsp], dim=1) if hsp is not None:
del hsp h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0: if len(hs) > 0:
output_shape = hs[-1].shape output_shape = hs[-1].shape
else: else:

596
comfy/ldm/sam3/detector.py Normal file
View File

@ -0,0 +1,596 @@
# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker
from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__
from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine
TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker}
from comfy.ops import cast_to_input
def box_cxcywh_to_xyxy(x):
cx, cy, w, h = x.unbind(-1)
return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
def gen_sineembed_for_position(pos_tensor, num_feats=256):
"""Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats)."""
assert num_feats % 2 == 0
hdim = num_feats // 2
freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim)
embeds = []
for c in range(pos_tensor.shape[-1]):
raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs
embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2))
return torch.cat(embeds, dim=-1).to(pos_tensor.dtype)
class SplitMHA(nn.Module):
"""Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight)."""
def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def forward(self, q_input, k_input=None, v_input=None, mask=None):
q = self.q_proj(q_input)
if k_input is None:
k = self.k_proj(q_input)
v = self.v_proj(q_input)
else:
k = self.k_proj(k_input)
v = self.v_proj(v_input if v_input is not None else k_input)
if mask is not None and mask.ndim == 2:
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
dtype = q.dtype # manual_cast may produce mixed dtypes
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
return self.out_proj(out)
class MLPWithNorm(nn.Module):
"""MLP with residual connection and output LayerNorm."""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([
operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype)
for i in range(num_layers)
])
self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype)
self.residual = residual and (input_dim == output_dim)
def forward(self, x):
orig = x
for i, layer in enumerate(self.layers):
x = layer(x)
if i < len(self.layers) - 1:
x = F.relu(x)
if self.residual:
x = x + orig
return self.out_norm(x)
class EncoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x, pos, text_memory=None, text_mask=None):
normed = self.norm1(x)
q_k = normed + pos
x = x + self.self_attn(q_k, q_k, normed)
if text_memory is not None:
normed = self.norm2(x)
x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask)
normed = self.norm3(x)
x = x + self.linear2(F.relu(self.linear1(normed)))
return x
class TransformerEncoder(nn.Module):
"""Checkpoint: transformer.encoder.layers.N.*"""
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, x, pos, text_memory=None, text_mask=None):
for layer in self.layers:
x = layer(x, pos, text_memory, text_mask)
return x
class DecoderLayer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype)
def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None):
q_k = x + x_pos
x = self.norm2(x + self.self_attn(q_k, q_k, x))
if text_memory is not None:
x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask))
x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias))
x = self.norm3(x + self.linear2(F.relu(self.linear1(x))))
return x
class TransformerDecoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.num_queries = num_queries
self.layers = nn.ModuleList([
DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype)
self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes
self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each)
self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations)
self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations)
self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
@staticmethod
def _inverse_sigmoid(x):
return torch.log(x / (1 - x + 1e-6) + 1e-6)
def _compute_box_rpb(self, ref_points, H, W):
"""Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias."""
boxes_xyxy = box_cxcywh_to_xyxy(ref_points)
B, Q, _ = boxes_xyxy.shape
coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H
coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W
deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2]
deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2]
log2_8 = float(math.log2(8))
def log_scale(d):
return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8
rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype))
rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype))
bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2)
pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype)
return torch.cat([pres_bias, bias], dim=2)
def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72):
B = memory.shape[0]
tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1)
presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1)
ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid()
for layer_idx, layer in enumerate(self.layers):
query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model))
tgt_with_pres = torch.cat([presence_out, tgt], dim=1)
pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1)
tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos,
text_memory, text_mask, self._compute_box_rpb(ref_points, H, W))
presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:]
if layer_idx < len(self.layers) - 1:
ref_inv = self._inverse_sigmoid(ref_points)
ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach()
query_out = self.norm(tgt)
ref_inv = self._inverse_sigmoid(ref_points)
boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid()
presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1)
return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence}
class Transformer(nn.Module):
def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6,
num_queries=200, device=None, dtype=None, operations=None):
super().__init__()
self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations)
self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations)
class GeometryEncoder(nn.Module):
def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.roi_size = roi_size
self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype)
self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype)
self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype)
self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype)
self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype)
self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.encode = nn.ModuleList([
EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
def _encode_points(self, coords, labels, img_feat_2d):
"""Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized."""
B, N, _ = coords.shape
embed = self.points_direct_project(coords)
# Pool features from backbone at point locations via grid_sample
grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1]
sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1]
embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C]
# Positional encoding of coordinates
x, y = coords[:, :, 0], coords[:, :, 1] # [B, N]
pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten())
enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1)
embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def _encode_boxes(self, boxes, labels, img_feat_2d):
"""Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh."""
B, N, _ = boxes.shape
embed = self.boxes_direct_project(boxes)
# ROI align from backbone at box regions
H, W = img_feat_2d.shape[-2:]
boxes_xyxy = box_cxcywh_to_xyxy(boxes)
scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device)
boxes_scaled = boxes_xyxy * scale
sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size)
proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C]
embed = embed + proj
# Positional encoding of box center + size
cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3]
enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten())
enc = enc.view(B, N, -1)
embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed))
embed = embed + cast_to_input(self.label_embed(labels.long()), embed)
return embed
def forward(self, points=None, boxes=None, image_features=None):
"""Encode geometry prompts. image_features: [B, HW, C] flattened backbone features."""
# Prepare 2D image features for pooling
img_feat_2d = None
if image_features is not None:
B = image_features.shape[0]
HW, C = image_features.shape[1], image_features.shape[2]
hw = int(math.sqrt(HW))
img_normed = self.img_pre_norm(image_features)
img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw)
embeddings = []
if points is not None:
coords, labels = points
embeddings.append(self._encode_points(coords, labels, img_feat_2d))
if boxes is not None:
B = boxes.shape[0]
box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device)
embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d))
if not embeddings:
return None
geo = torch.cat(embeddings, dim=1)
geo = self.norm(geo)
if image_features is not None:
for layer in self.encode:
geo = layer(geo, torch.zeros_like(geo), image_features)
geo = self.encode_norm(geo)
return self.final_proj(geo)
class PixelDecoder(nn.Module):
"""Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation."""
def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None):
super().__init__()
self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)])
self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)])
def forward(self, backbone_features):
prev = backbone_features[-1]
for i, feat in enumerate(backbone_features[:-1][::-1]):
prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest"))))
return prev
class MaskPredictor(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations)
def forward(self, query_embeddings, pixel_features):
mask_embed = self.mask_embed(query_embeddings)
return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features)
class SegmentationHead(nn.Module):
def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None):
super().__init__()
self.d_model = d_model
self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations)
self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations)
self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype)
self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype)
def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None):
if encoder_hidden_states is not None and prompt is not None:
enc_normed = self.cross_attn_norm(encoder_hidden_states)
enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask)
encoder_hidden_states = enc_cross + encoder_hidden_states
if encoder_hidden_states is not None:
B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1]
encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W)
backbone_features = list(backbone_features)
backbone_features[-1] = encoder_visual
pixel_features = self.pixel_decoder(backbone_features)
instance_features = self.instance_seg_head(pixel_features)
masks = self.mask_predictor(query_embeddings, instance_features)
return masks
class DotProductScoring(nn.Module):
def __init__(self, d_model=256, device=None, dtype=None, operations=None):
super().__init__()
self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype)
self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations)
self.scale = 1.0 / (d_model ** 0.5)
def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None):
prompt = self.prompt_mlp(prompt_embeddings)
if prompt_mask is not None:
weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype)
pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1)
else:
pooled = prompt.mean(dim=1)
hs = self.hs_proj(query_embeddings)
pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype)
scores = torch.matmul(hs, pp)
return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1)
class SAM3Detector(nn.Module):
def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
image_model = kwargs.pop("image_model", "SAM3")
for k in ("num_heads", "num_head_channels"):
kwargs.pop(k, None)
multiplex = image_model == "SAM31"
# SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0)
self.scalp = 0 if multiplex else 1
self.backbone = nn.ModuleDict({
"vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs),
"language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}),
})
self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations)
self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations)
self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations)
def _get_backbone_features(self, images):
"""Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions)."""
bb = self.backbone["vision_backbone"]
if bb.multiplex:
all_f, all_p, tf, tp = bb(images, tracker_mode="propagation")
else:
all_f, all_p, tf, tp = bb(images, need_tracker=True)
return all_f, all_p, tf, tp
@staticmethod
def _run_geo_layer(layer, x, memory, memory_pos):
x = x + layer.self_attn(layer.norm1(x))
x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory)
x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x))))
return x
def _detect(self, features, positions, text_embeddings=None, text_mask=None,
points=None, boxes=None):
"""Shared detection: geometry encoding, transformer, scoring, segmentation."""
B = features[0].shape[0]
# Scalp for encoder (use top-level feature), but keep all levels for segmentation head
seg_features = features
if self.scalp > 0:
features = features[:-self.scalp]
positions = positions[:-self.scalp]
enc_feat, enc_pos = features[-1], positions[-1]
_, _, H, W = enc_feat.shape
img_flat = enc_feat.flatten(2).permute(0, 2, 1)
pos_flat = enc_pos.flatten(2).permute(0, 2, 1)
has_prompts = text_embeddings is not None or points is not None or boxes is not None
if has_prompts:
geo_enc = self.geometry_encoder
geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat)
geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1)))
for layer in geo_enc.encode:
geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat)
geo_cls = geo_enc.encode_norm(geo_cls)
if text_embeddings is not None and text_embeddings.shape[0] != B:
text_embeddings = text_embeddings.expand(B, -1, -1)
if text_mask is not None and text_mask.shape[0] != B:
text_mask = text_mask.expand(B, -1)
parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None]
text_embeddings = torch.cat(parts, dim=1)
n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0)
if text_mask is not None:
text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1)
else:
text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device)
memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask)
dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W)
query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"]
if text_embeddings is not None:
scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask)
else:
scores = torch.zeros(B, query_out.shape[1], device=query_out.device)
masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask)
return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out
def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None):
features, positions, _, _ = self._get_backbone_features(images)
if text_embeddings is not None:
text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings)
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, dec_out = self._detect(
features, positions, text_embeddings, text_mask, points, boxes)
if orig_size is not None:
oh, ow = orig_size
boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype)
masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False)
return {
"boxes": boxes_xyxy,
"scores": scores,
"masks": masks,
"presence": dec_out.get("presence"),
}
def forward_from_trunk(self, trunk_out, text_embeddings, text_mask):
"""Run detection using a pre-computed ViTDet trunk output.
text_embeddings must already be resized through language_backbone.resizer.
Returns dict with boxes (normalized xyxy), scores, masks at detector resolution.
"""
bb = self.backbone["vision_backbone"]
features = [conv(trunk_out) for conv in bb.convs]
positions = [cast_to_input(bb.position_encoding(f), f) for f in features]
if text_mask is not None:
text_mask = text_mask.bool()
boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask)
return {"boxes": boxes_xyxy, "scores": scores, "masks": masks}
class SAM3Model(nn.Module):
def __init__(self, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
image_model = kwargs.get("image_model", "SAM3")
tracker_cls = TRACKER_CLASSES[image_model]
self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs)
self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs)
def forward(self, images, **kwargs):
return self.detector(images, **kwargs)
def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None):
"""Interactive segmentation using SAM decoder with point/box/mask prompts.
Args:
images: [B, 3, 1008, 1008] preprocessed images
point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space
box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space
mask_inputs: [B, 1, H, W] coarse mask logits to refine
Returns:
[B, 1, image_size, image_size] high-res mask logits
"""
bb = self.detector.backbone["vision_backbone"]
if bb.multiplex:
_, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive")
else:
_, _, tracker_features, tracker_positions = bb(images, need_tracker=True)
if self.detector.scalp > 0:
tracker_features = tracker_features[:-self.detector.scalp]
tracker_positions = tracker_positions[:-self.detector.scalp]
high_res = list(tracker_features[:-1])
backbone_feat = tracker_features[-1]
B, C, H, W = backbone_feat.shape
# Add no-memory embedding (init frame path)
no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None)
if no_mem is None:
no_mem = getattr(self.tracker, 'no_mem_embed', None)
if no_mem is not None:
feat_flat = backbone_feat.flatten(2).permute(0, 2, 1)
feat_flat = feat_flat + cast_to_input(no_mem, feat_flat)
backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2)
num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1)
_, high_res_masks, _, _ = self.tracker._forward_sam_heads(
backbone_features=backbone_feat,
point_inputs=point_inputs,
mask_inputs=mask_inputs,
box_inputs=box_inputs,
high_res_features=high_res,
multimask_output=(0 < num_pts <= 1),
)
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
def backbone_fn(frame, frame_idx=None):
trunk_out = bb.trunk(frame)
if bb.multiplex:
_, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True)
else:
_, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True)
return tf, tp, trunk_out
detect_fn = None
if text_prompts:
resizer = self.detector.backbone["language_backbone"]["resizer"]
resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts]
def detect_fn(trunk_out):
all_scores, all_masks = [], []
for emb, mask in resized:
det = self.detector.forward_from_trunk(trunk_out, emb, mask)
all_scores.append(det["scores"])
all_masks.append(det["masks"])
return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)}
if hasattr(self.tracker, 'track_video_with_detection'):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)

425
comfy/ldm/sam3/sam.py Normal file
View File

@ -0,0 +1,425 @@
# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.layers import EmbedND
from comfy.ops import cast_to_input
class MLP(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None):
super().__init__()
dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)])
self.sigmoid_output = sigmoid_output
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x)
return torch.sigmoid(x) if self.sigmoid_output else x
class SAMAttention(nn.Module):
def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
internal_dim = embedding_dim // downsample_rate
kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype)
def forward(self, q, k, v):
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
class TwoWayAttentionBlock(nn.Module):
def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None):
super().__init__()
self.skip_first_layer_pe = skip_first_layer_pe
self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations)
self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype))
self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, queries, keys, query_pe, key_pe):
if self.skip_first_layer_pe:
queries = self.norm1(self.self_attn(queries, queries, queries))
else:
q = queries + query_pe
queries = self.norm1(queries + self.self_attn(q, q, queries))
q, k = queries + query_pe, keys + key_pe
queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys))
queries = self.norm3(queries + self.mlp(queries))
q, k = queries + query_pe, keys + key_pe
keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries))
return queries, keys
class TwoWayTransformer(nn.Module):
def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate,
skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations)
for i in range(depth)
])
self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations)
self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype)
def forward(self, image_embedding, image_pe, point_embedding):
queries, keys = point_embedding, image_embedding
for layer in self.layers:
queries, keys = layer(queries, keys, point_embedding, image_pe)
q, k = queries + point_embedding, keys + image_pe
queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys))
return queries, keys
class PositionEmbeddingRandom(nn.Module):
"""Fourier feature positional encoding with random gaussian projection."""
def __init__(self, num_pos_feats=64, scale=None):
super().__init__()
self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats))
def _encode(self, normalized_coords):
"""Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32."""
orig_dtype = normalized_coords.dtype
proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32)
projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix
return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype)
def forward(self, size, device=None):
h, w = size
dev = device if device is not None else self.positional_encoding_gaussian_matrix.device
ones = torch.ones((h, w), device=dev, dtype=torch.float32)
norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1)
return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0)
def forward_with_coords(self, pixel_coords, image_size):
norm = pixel_coords.clone()
norm[:, :, 0] /= image_size[1]
norm[:, :, 1] /= image_size[0]
return self._encode(norm)
# ViTDet backbone + FPN neck
def window_partition(x: torch.Tensor, window_size: int):
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw):
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0):
"""Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2]."""
t = torch.arange(end_x * end_y, dtype=torch.float32)
ids = torch.stack([(t % end_x) * scale_pos,
torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1)
return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0))
class _ViTMLP(nn.Module):
def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None):
super().__init__()
hidden = int(dim * mlp_ratio)
self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(self.act(self.fc1(x)))
class Attention(nn.Module):
"""ViTDet multi-head attention with fused QKV projection."""
def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.use_rope = use_rope
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x, freqs_cis=None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
if self.use_rope and freqs_cis is not None:
q, k = apply_rope(q, k, freqs_cis)
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations)
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations)
def forward(self, x, freqs_cis=None):
shortcut = x
x = self.norm1(x)
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = x.view(x.shape[0], self.window_size * self.window_size, -1)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(-1, self.window_size, self.window_size, x.shape[-1])
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
else:
B, H, W, C = x.shape
x = x.view(B, H * W, C)
x = self.attn(x, freqs_cis=freqs_cis)
x = x.view(B, H, W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchEmbed(nn.Module):
def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None):
super().__init__()
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype)
def forward(self, x):
return self.proj(x)
class ViTDet(nn.Module):
def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24,
global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_heads = num_heads
self.global_att_blocks = set(global_att_blocks)
self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations)
num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype))
self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype)
grid_size = img_size // patch_size
pretrain_grid = pretrain_img_size // patch_size
self.blocks = nn.ModuleList()
for i in range(depth):
is_global = i in self.global_att_blocks
self.blocks.append(Block(
embed_dim, num_heads, mlp_ratio, qkv_bias,
window_size=0 if is_global else window_size,
use_rope=use_rope,
device=device, dtype=dtype, operations=operations,
))
if use_rope:
rope_scale = pretrain_grid / grid_size
self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False)
self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False)
else:
self.freqs_cis = None
self.freqs_cis_window = None
def _get_pos_embed(self, num_tokens):
pos = self.pos_embed
if pos.shape[1] == num_tokens:
return pos
cls_pos = pos[:, :1]
spatial_pos = pos[:, 1:]
old_size = int(math.sqrt(spatial_pos.shape[1]))
new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size
spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
tiles_h = new_size // old_size + 1
tiles_w = new_size // old_size + 1
tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size]
tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
return torch.cat([cls_pos, tiled], dim=1)
def forward(self, x):
x = self.patch_embed(x)
B, C, Hp, Wp = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C)
pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x)
x = x + pos[:, 1:Hp * Wp + 1]
x = x.view(B, Hp, Wp, C)
x = self.ln_pre(x)
freqs_cis_global = self.freqs_cis
freqs_cis_win = self.freqs_cis_window
if freqs_cis_global is not None:
freqs_cis_global = cast_to_input(freqs_cis_global, x)
if freqs_cis_win is not None:
freqs_cis_win = cast_to_input(freqs_cis_win, x)
for block in self.blocks:
fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global
x = block(x, freqs_cis=fc)
return x.permute(0, 3, 1, 2)
class FPNScaleConv(nn.Module):
def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None):
super().__init__()
if scale == 4.0:
self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 4
elif scale == 2.0:
self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype)
proj_in = in_dim // 2
elif scale == 1.0:
proj_in = in_dim
elif scale == 0.5:
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
proj_in = in_dim
self.scale = scale
self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype)
self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype)
def forward(self, x):
if self.scale == 4.0:
x = F.gelu(self.dconv_2x2_0(x))
x = self.dconv_2x2_1(x)
elif self.scale == 2.0:
x = self.dconv_2x2(x)
elif self.scale == 0.5:
x = self.pool(x)
x = self.conv_1x1(x)
x = self.conv_3x3(x)
return x
class PositionEmbeddingSine(nn.Module):
"""2D sinusoidal position encoding (DETR-style) with result caching."""
def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None):
super().__init__()
assert num_pos_feats % 2 == 0
self.half_dim = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
self.scale = scale if scale is not None else 2 * math.pi
self._cache = {}
def _sincos(self, vals):
"""Encode 1D values to interleaved sin/cos features."""
freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim)
raw = vals[..., None] * self.scale / freqs
return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2)
def _encode_xy(self, x, y):
"""Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim]."""
dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim)
pos_x = x[:, None] * self.scale / dim_t
pos_y = y[:, None] * self.scale / dim_t
pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1)
pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1)
return pos_x, pos_y
def encode_boxes(self, cx, cy, w, h):
"""Encode box center + size to [N, d_model+2] features."""
pos_x, pos_y = self._encode_xy(cx, cy)
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
def forward(self, x):
B, C, H, W = x.shape
key = (H, W, x.device)
if key not in self._cache:
gy = torch.arange(H, dtype=torch.float32, device=x.device)
gx = torch.arange(W, dtype=torch.float32, device=x.device)
if self.normalize:
gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6)
yy, xx = torch.meshgrid(gy, gx, indexing="ij")
self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0)
return self._cache[key].expand(B, -1, -1, -1)
class SAM3VisionBackbone(nn.Module):
def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True)
self.multiplex = multiplex
fpn_args = dict(device=device, dtype=dtype, operations=operations)
if multiplex:
scales = [4.0, 2.0, 1.0]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
else:
scales = [4.0, 2.0, 1.0, 0.5]
self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales])
def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False):
backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images)
if tracker_only:
# Skip detector FPN when only tracker features are needed (video tracking)
if self.multiplex:
tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs
else:
tracker_convs = self.sam2_convs
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return None, None, tracker_features, tracker_positions
features = [conv(backbone_out) for conv in self.convs]
positions = [cast_to_input(self.position_encoding(f), f) for f in features]
if self.multiplex:
if tracker_mode == "propagation":
tracker_convs = self.propagation_convs
elif tracker_mode == "interactive":
tracker_convs = self.interactive_convs
else:
return features, positions, None, None
elif need_tracker:
tracker_convs = self.sam2_convs
else:
return features, positions, None, None
tracker_features = [conv(backbone_out) for conv in tracker_convs]
tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features]
return features, positions, tracker_features, tracker_positions

1785
comfy/ldm/sam3/tracker.py Normal file

File diff suppressed because it is too large Load Diff

View File

View File

@ -0,0 +1,226 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from comfy.ldm.modules.diffusionmodules.openaimodel import Downsample, TimestepEmbedSequential, ResBlock, SpatialTransformer
from comfy.ldm.modules.attention import optimized_attention
class ZeroSFT(nn.Module):
def __init__(self, label_nc, norm_nc, concat_channels=0, dtype=None, device=None, operations=None):
super().__init__()
ks = 3
pw = ks // 2
self.param_free_norm = operations.GroupNorm(32, norm_nc + concat_channels, dtype=dtype, device=device)
nhidden = 128
self.mlp_shared = nn.Sequential(
operations.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw, dtype=dtype, device=device),
nn.SiLU()
)
self.zero_mul = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
self.zero_add = operations.Conv2d(nhidden, norm_nc + concat_channels, kernel_size=ks, padding=pw, dtype=dtype, device=device)
self.zero_conv = operations.Conv2d(label_nc, norm_nc, 1, 1, 0, dtype=dtype, device=device)
self.pre_concat = bool(concat_channels != 0)
def forward(self, c, h, h_ori=None, control_scale=1):
if h_ori is not None and self.pre_concat:
h_raw = torch.cat([h_ori, h], dim=1)
else:
h_raw = h
h = h + self.zero_conv(c)
if h_ori is not None and self.pre_concat:
h = torch.cat([h_ori, h], dim=1)
actv = self.mlp_shared(c)
gamma = self.zero_mul(actv)
beta = self.zero_add(actv)
h = self.param_free_norm(h)
h = torch.addcmul(h + beta, h, gamma)
if h_ori is not None and not self.pre_concat:
h = torch.cat([h_ori, h], dim=1)
return torch.lerp(h_raw, h, control_scale)
class _CrossAttnInner(nn.Module):
"""Inner cross-attention module matching the state_dict layout of the original CrossAttention."""
def __init__(self, query_dim, context_dim, heads, dim_head, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
)
def forward(self, x, context):
q = self.to_q(x)
k = self.to_k(context)
v = self.to_v(context)
return self.to_out(optimized_attention(q, k, v, self.heads))
class ZeroCrossAttn(nn.Module):
def __init__(self, context_dim, query_dim, dtype=None, device=None, operations=None):
super().__init__()
heads = query_dim // 64
dim_head = 64
self.attn = _CrossAttnInner(query_dim, context_dim, heads, dim_head, dtype=dtype, device=device, operations=operations)
self.norm1 = operations.GroupNorm(32, query_dim, dtype=dtype, device=device)
self.norm2 = operations.GroupNorm(32, context_dim, dtype=dtype, device=device)
def forward(self, context, x, control_scale=1):
b, c, h, w = x.shape
x_in = x
x = self.attn(
self.norm1(x).flatten(2).transpose(1, 2),
self.norm2(context).flatten(2).transpose(1, 2),
).transpose(1, 2).unflatten(2, (h, w))
return x_in + x * control_scale
class GLVControl(nn.Module):
"""SUPIR's Guided Latent Vector control encoder. Truncated UNet (input + middle blocks only)."""
def __init__(
self,
in_channels=4,
model_channels=320,
num_res_blocks=2,
attention_resolutions=(4, 2),
channel_mult=(1, 2, 4),
num_head_channels=64,
transformer_depth=(1, 2, 10),
context_dim=2048,
adm_in_channels=2816,
use_linear_in_transformer=True,
use_checkpoint=False,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.model_channels = model_channels
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
operations.Linear(model_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
)
self.label_emb = nn.Sequential(
nn.Sequential(
operations.Linear(adm_in_channels, time_embed_dim, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device),
)
)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
)
])
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(num_res_blocks):
layers = [
ResBlock(ch, time_embed_dim, 0, out_channels=mult * model_channels,
dtype=dtype, device=device, operations=operations)
]
ch = mult * model_channels
if ds in attention_resolutions:
num_heads = ch // num_head_channels
layers.append(
SpatialTransformer(ch, num_heads, num_head_channels,
depth=transformer_depth[level], context_dim=context_dim,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
dtype=dtype, device=device, operations=operations)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
if level != len(channel_mult) - 1:
self.input_blocks.append(
TimestepEmbedSequential(
Downsample(ch, True, out_channels=ch, dtype=dtype, device=device, operations=operations)
)
)
ds *= 2
num_heads = ch // num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
SpatialTransformer(ch, num_heads, num_head_channels,
depth=transformer_depth[-1], context_dim=context_dim,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint,
dtype=dtype, device=device, operations=operations),
ResBlock(ch, time_embed_dim, 0, dtype=dtype, device=device, operations=operations),
)
self.input_hint_block = TimestepEmbedSequential(
operations.Conv2d(in_channels, model_channels, 3, padding=1, dtype=dtype, device=device)
)
def forward(self, x, timesteps, xt, context=None, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
emb = self.time_embed(t_emb) + self.label_emb(y)
guided_hint = self.input_hint_block(x, emb, context)
hs = []
h = xt
for module in self.input_blocks:
if guided_hint is not None:
h = module(h, emb, context)
h += guided_hint
guided_hint = None
else:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
hs.append(h)
return hs
class SUPIR(nn.Module):
"""
SUPIR model containing GLVControl (control encoder) and project_modules (adapters).
State dict keys match the original SUPIR checkpoint layout:
control_model.* -> GLVControl
project_modules.* -> nn.ModuleList of ZeroSFT/ZeroCrossAttn
"""
def __init__(self, device=None, dtype=None, operations=None):
super().__init__()
self.control_model = GLVControl(dtype=dtype, device=device, operations=operations)
project_channel_scale = 2
cond_output_channels = [320] * 4 + [640] * 3 + [1280] * 3
project_channels = [int(c * project_channel_scale) for c in [160] * 4 + [320] * 3 + [640] * 3]
concat_channels = [320] * 2 + [640] * 3 + [1280] * 4 + [0]
cross_attn_insert_idx = [6, 3]
self.project_modules = nn.ModuleList()
for i in range(len(cond_output_channels)):
self.project_modules.append(ZeroSFT(
project_channels[i], cond_output_channels[i],
concat_channels=concat_channels[i],
dtype=dtype, device=device, operations=operations,
))
for i in cross_attn_insert_idx:
self.project_modules.insert(i, ZeroCrossAttn(
cond_output_channels[i], concat_channels[i],
dtype=dtype, device=device, operations=operations,
))

View File

@ -0,0 +1,103 @@
import torch
from comfy.ldm.modules.diffusionmodules.openaimodel import Upsample
class SUPIRPatch:
"""
Holds GLVControl (control encoder) + project_modules (ZeroSFT/ZeroCrossAttn adapters).
Runs GLVControl lazily on first patch invocation per step, applies adapters through
middle_block_after_patch, output_block_merge_patch, and forward_timestep_embed_patch.
"""
SIGMA_MAX = 14.6146
def __init__(self, model_patch, project_modules, hint_latent, strength_start, strength_end):
self.model_patch = model_patch # CoreModelPatcher wrapping GLVControl
self.project_modules = project_modules # nn.ModuleList of ZeroSFT/ZeroCrossAttn
self.hint_latent = hint_latent # encoded LQ image latent
self.strength_start = strength_start
self.strength_end = strength_end
self.cached_features = None
self.adapter_idx = 0
self.control_idx = 0
self.current_control_idx = 0
self.active = True
def _ensure_features(self, kwargs):
"""Run GLVControl on first call per step, cache results."""
if self.cached_features is not None:
return
x = kwargs["x"]
b = x.shape[0]
hint = self.hint_latent.to(device=x.device, dtype=x.dtype)
if hint.shape[0] != b:
hint = hint.expand(b, -1, -1, -1) if hint.shape[0] == 1 else hint.repeat((b + hint.shape[0] - 1) // hint.shape[0], 1, 1, 1)[:b]
self.cached_features = self.model_patch.model.control_model(
hint, kwargs["timesteps"], x,
kwargs["context"], kwargs["y"]
)
self.adapter_idx = len(self.project_modules) - 1
self.control_idx = len(self.cached_features) - 1
def _get_control_scale(self, kwargs):
if self.strength_start == self.strength_end:
return self.strength_end
sigma = kwargs["transformer_options"].get("sigmas")
if sigma is None:
return self.strength_end
s = sigma[0].item() if sigma.dim() > 0 else sigma.item()
t = min(s / self.SIGMA_MAX, 1.0)
return t * (self.strength_start - self.strength_end) + self.strength_end
def middle_after(self, kwargs):
"""middle_block_after_patch: run GLVControl lazily, apply last adapter after middle block."""
self.cached_features = None # reset from previous step
self.current_scale = self._get_control_scale(kwargs)
self.active = self.current_scale > 0
if not self.active:
return {"h": kwargs["h"]}
self._ensure_features(kwargs)
h = kwargs["h"]
h = self.project_modules[self.adapter_idx](
self.cached_features[self.control_idx], h, control_scale=self.current_scale
)
self.adapter_idx -= 1
self.control_idx -= 1
return {"h": h}
def output_block(self, h, hsp, transformer_options):
"""output_block_patch: ZeroSFT adapter fusion replaces cat([h, hsp]). Returns (h, None) to skip cat."""
if not self.active:
return h, hsp
self.current_control_idx = self.control_idx
h = self.project_modules[self.adapter_idx](
self.cached_features[self.control_idx], hsp, h, control_scale=self.current_scale
)
self.adapter_idx -= 1
self.control_idx -= 1
return h, None
def pre_upsample(self, layer, x, emb, context, transformer_options, output_shape, *args, **kw):
"""forward_timestep_embed_patch for Upsample: extra cross-attn adapter before upsample."""
block_type, _ = transformer_options["block"]
if block_type == "output" and self.active and self.cached_features is not None:
x = self.project_modules[self.adapter_idx](
self.cached_features[self.current_control_idx], x, control_scale=self.current_scale
)
self.adapter_idx -= 1
return layer(x, output_shape=output_shape)
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.cached_features = None
if self.hint_latent is not None:
self.hint_latent = self.hint_latent.to(device_or_dtype)
return self
def models(self):
return [self.model_patch]
def register(self, model_patcher):
"""Register all patches on a cloned model patcher."""
model_patcher.set_model_patch(self.middle_after, "middle_block_after_patch")
model_patcher.set_model_output_block_patch(self.output_block)
model_patcher.set_model_patch((Upsample, self.pre_upsample), "forward_timestep_embed_patch")

276
comfy/ldm/wan/ar_model.py Normal file
View File

@ -0,0 +1,276 @@
"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.
Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.
Reference: https://github.com/thu-ml/Causal-Forcing
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
sinusoidal_embedding_1d,
repeat_e,
WanModel,
WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management
class CausalWanSelfAttention(nn.Module):
"""Self-attention with KV cache support for autoregressive inference."""
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qk_norm = qk_norm
self.eps = eps
ops = operation_settings.get("operations")
device = operation_settings.get("device")
dtype = operation_settings.get("dtype")
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
v = self.v(x).view(b, s, n, d)
if kv_cache is None:
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v.view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
else:
end = kv_cache["end"]
new_end = end + s
# Roped K and plain V go into cache
kv_cache["k"][:, end:new_end] = k
kv_cache["v"][:, end:new_end] = v
kv_cache["end"] = new_end
x = optimized_attention(
q.view(b, s, n * d),
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
return x
class CausalWanAttentionBlock(WanAttentionBlock):
"""Transformer block with KV-cached self-attention and cross-attention caching."""
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
operation_settings=operation_settings)
self.self_attn = CausalWanSelfAttention(
dim, num_heads, window_size, qk_norm, eps,
operation_settings=operation_settings)
def forward(self, x, e, freqs, context, context_img_len=257,
kv_cache=None, crossattn_cache=None, transformer_options={}):
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# Self-attention with optional KV cache
x = x.contiguous()
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# Cross-attention with optional caching
if crossattn_cache is not None and crossattn_cache.get("is_init"):
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
x_ca = optimized_attention(
q, crossattn_cache["k"], crossattn_cache["v"],
heads=self.num_heads, transformer_options=transformer_options)
x = x + self.cross_attn.o(x_ca)
else:
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if crossattn_cache is not None:
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
crossattn_cache["v"] = self.cross_attn.v(context)
crossattn_cache["is_init"] = True
# FFN
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class CausalWanModel(WanModel):
"""
Wan 2.1 diffusion backbone with causal KV-cache support.
Same weight structure as WanModel -- loads identical state dicts.
Adds forward_block() for frame-by-frame autoregressive inference.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None):
super().__init__(
model_type=model_type, patch_size=patch_size, text_len=text_len,
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
wan_attn_block_class=CausalWanAttentionBlock,
device=device, dtype=dtype, operations=operations)
def forward_block(self, x, timestep, context, start_frame,
kv_caches, crossattn_caches, clip_fea=None):
"""
Forward one temporal block for autoregressive inference.
Args:
x: [B, C, block_frames, H, W] input latent for the current block
timestep: [B, block_frames] per-frame timesteps
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
start_frame: temporal frame index for RoPE offset
kv_caches: list of per-layer KV cache dicts
crossattn_caches: list of per-layer cross-attention cache dicts
clip_fea: optional CLIP features for I2V
Returns:
flow_pred: [B, C_out, block_frames, H, W] flow prediction
"""
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
bs, c, t, h, w = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# Per-frame time embedding
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# Text embedding (reuses crossattn_cache after first block)
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
# RoPE for current block's temporal position
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
# Transformer blocks
for i, block in enumerate(self.blocks):
x = block(x, e=e0, freqs=freqs, context=context,
context_img_len=context_img_len,
kv_cache=kv_caches[i],
crossattn_cache=crossattn_caches[i])
# Head
x = self.head(x, e)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w]
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
"""Create fresh KV caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"end": 0,
})
return caches
def init_crossattn_caches(self, batch_size, device, dtype):
"""Create fresh cross-attention caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({"is_init": False})
return caches
def reset_kv_caches(self, kv_caches):
"""Reset KV caches to empty (reuse allocated memory)."""
for cache in kv_caches:
cache["end"] = 0
def reset_crossattn_caches(self, crossattn_caches):
"""Reset cross-attention caches."""
for cache in crossattn_caches:
cache["is_init"] = False
@property
def head_dim(self):
return self.dim // self.num_heads
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
ar_state = transformer_options.get("ar_state")
if ar_state is not None:
bs = x.shape[0]
block_frames = x.shape[2]
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
return self.forward_block(
x=x, timestep=t_per_frame, context=context,
start_frame=ar_state["start_frame"],
kv_caches=ar_state["kv_caches"],
crossattn_caches=ar_state["crossattn_caches"],
clip_fea=clip_fea,
)
return super().forward(x, timestep, context, clip_fea=clip_fea,
time_dim_concat=time_dim_concat,
transformer_options=transformer_options, **kwargs)

View File

@ -17,6 +17,7 @@
""" """
from __future__ import annotations from __future__ import annotations
import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
import comfy.model_base import comfy.model_base
@ -342,6 +343,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras key_map["base_model.model.{}".format(key_lora)] = k # Official base model loras
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # LyCORIS/LoKR format
if isinstance(model, comfy.model_base.ErnieImage):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
return key_map return key_map
@ -467,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
weight = old_weight weight = old_weight
return weight return weight
def prefetch_prepared_value(value, allocate_buffer, stream):
if isinstance(value, torch.Tensor):
dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
return comfy.memory_management.interpret_gathered_like([value], dest)[0]
elif isinstance(value, weight_adapter.WeightAdapterBase):
return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
elif isinstance(value, tuple):
return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
elif isinstance(value, list):
return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
return value

View File

@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model import comfy.ldm.lumina.model
import comfy.ldm.wan.model import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model import comfy.ldm.hidream.model
import comfy.ldm.chroma.model import comfy.ldm.chroma.model
@ -53,8 +54,10 @@ import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.trellis2.model import comfy.ldm.trellis2.model
import comfy.ldm.ace.ace_step15 import comfy.ldm.ace.ace_step15
import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -81,6 +84,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11 IMG_TO_IMG_FLOW = 11
V_PREDICTION_DDPM = 12
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -115,6 +119,8 @@ def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW: elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW c = comfy.model_sampling.IMG_TO_IMG_FLOW
elif model_type == ModelType.V_PREDICTION_DDPM:
c = comfy.model_sampling.V_PREDICTION_DDPM
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -210,6 +216,11 @@ class BaseModel(torch.nn.Module):
if "latent_shapes" in extra_conds: if "latent_shapes" in extra_conds:
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
transformer_options = transformer_options.copy()
transformer_options["prefetch_dynamic_vbars"] = (
self.current_patcher is not None and self.current_patcher.is_dynamic()
)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output): if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output) model_output, _ = utils.pack_latents(model_output)
@ -579,8 +590,8 @@ class Stable_Zero123(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device)
self.cc_projection.weight.copy_(cc_projection_weight) self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone())
self.cc_projection.bias.copy_(cc_projection_bias) self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone())
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
@ -1356,6 +1367,13 @@ class WAN21(BaseModel):
return out return out
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
self.image_to_video = False
class WAN21_Vace(WAN21): class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
@ -1984,3 +2002,63 @@ class ErnieImage(BaseModel):
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out return out
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
class CogVideoX(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape = list(noise.shape)
shape[1] = extra_channels
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5 and image.ndim == 5:
if image.shape[-3] < noise.shape[-3]:
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
elif image.shape[-3] > noise.shape[-3]:
image = image[:, :, :noise.shape[-3]]
for i in range(0, image.shape[1], latent_dim):
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
if image.shape[1] > extra_channels:
image = image[:, :extra_channels]
elif image.shape[1] < extra_channels:
repeats = extra_channels // image.shape[1]
remainder = extra_channels % image.shape[1]
parts = [image] * repeats
if remainder > 0:
parts.append(image[:, :remainder])
image = torch.cat(parts, dim=1)
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
if self.diffusion_model.ofs_proj_dim is not None:
ofs = kwargs.get("ofs", None)
if ofs is None:
noise = kwargs.get("noise", None)
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
out['ofs'] = comfy.conds.CONDRegular(ofs)
return out

View File

@ -506,6 +506,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
dit_config = {}
dit_config["image_model"] = "cogvideox"
# Extract config from weight shapes
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
time_embed_dim = norm1_weight.shape[1]
dim = norm1_weight.shape[0] // 6
dit_config["num_attention_heads"] = dim // 64
dit_config["attention_head_dim"] = 64
dit_config["time_embed_dim"] = time_embed_dim
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
# Detect in_channels from patch_embed
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
if patch_proj_key in state_dict_keys:
w = state_dict[patch_proj_key]
if w.ndim == 4:
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
dit_config["in_channels"] = w.shape[1]
dit_config["patch_size"] = w.shape[2]
elif w.ndim == 2:
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
dit_config["patch_size"] = 2
dit_config["patch_size_t"] = 2
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
if text_proj_key in state_dict_keys:
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
# Detect OFS embedding
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
if ofs_key in state_dict_keys:
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
# Detect positional embedding type
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
if pos_key in state_dict_keys:
dit_config["use_learned_positional_embeddings"] = True
dit_config["use_rotary_positional_embeddings"] = False
else:
dit_config["use_learned_positional_embeddings"] = False
dit_config["use_rotary_positional_embeddings"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "wan2.1" dit_config["image_model"] = "wan2.1"
@ -734,6 +782,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "ernie" dit_config["image_model"] = "ernie"
return dit_config return dit_config
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "SAM3"
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
dit_config["image_model"] = "SAM31"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None
@ -889,6 +945,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
return model_config return model_config
def unet_prefix_from_state_dict(state_dict): def unet_prefix_from_state_dict(state_dict):
# SAM3: detector.* and tracker.* at top level, no common prefix
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
return ""
candidates = ["model.diffusion_model.", #ldm/sgm models candidates = ["model.diffusion_model.", #ldm/sgm models
"model.model.", #audio models "model.model.", #audio models
"net.", #cosmos "net.", #cosmos

View File

@ -31,6 +31,7 @@ from contextlib import nullcontext
import comfy.memory_management import comfy.memory_management
import comfy.utils import comfy.utils
import comfy.quant_ops import comfy.quant_ops
import comfy_aimdo.vram_buffer
class VRAMState(Enum): class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram DISABLED = 0 #No vram present: no need to move models to vram
@ -112,10 +113,6 @@ if args.directml is not None:
# torch_directml.disable_tiled_resources(True) # torch_directml.disable_tiled_resources(True)
lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default. lowvram_available = False #TODO: need to find a way to get free memory in directml before this can be enabled by default.
try:
import intel_extension_for_pytorch as ipex # noqa: F401
except:
pass
try: try:
_ = torch.xpu.device_count() _ = torch.xpu.device_count()
@ -583,9 +580,6 @@ class LoadedModel:
real_model = self.model.model real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
with torch.no_grad():
real_model = ipex.optimize(real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.real_model = weakref.ref(real_model) self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models) self.model_finalizer = weakref.finalize(real_model, cleanup_models)
@ -663,6 +657,7 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []
@ -726,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
else: else:
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory()) minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models_temp = set() # Order-preserving dedup. A plain set() would randomize iteration order across runs
models_temp = {}
for m in models: for m in models:
models_temp.add(m) models_temp[m] = None
for mm in m.model_patches_models(): for mm in m.model_patches_models():
models_temp.add(mm) models_temp[mm] = None
models = models_temp models = list(models_temp)
models.reverse()
models_to_load = [] models_to_load = []
@ -1181,6 +1178,10 @@ stream_counters = {}
STREAM_CAST_BUFFERS = {} STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
def get_cast_buffer(offload_stream, device, size, ref): def get_cast_buffer(offload_stream, device, size, ref):
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
@ -1214,13 +1215,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
return cast_buffer return cast_buffer
def get_aimdo_cast_buffer(offload_stream, device):
cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
if cast_buffer is None:
cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def reset_cast_buffers(): def reset_cast_buffers():
global LARGEST_CASTED_WEIGHT global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS: LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
offload_stream.synchronize() for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize() synchronize()
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
def get_offload_stream(device): def get_offload_stream(device):
@ -1580,10 +1594,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
if is_intel_xpu(): if is_intel_xpu():
if torch_version_numeric < (2, 3): return torch.xpu.get_device_properties(device).has_fp16
return True
else:
return torch.xpu.get_device_properties(device).has_fp16
if is_ascend_npu(): if is_ascend_npu():
return True return True
@ -1649,10 +1660,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False return False
if is_intel_xpu(): if is_intel_xpu():
if torch_version_numeric < (2, 3): return torch.xpu.is_bf16_supported()
return True
else:
return torch.xpu.is_bf16_supported()
if is_ascend_npu(): if is_ascend_npu():
return True return True
@ -1783,6 +1791,7 @@ def soft_empty_cache(force=False):
if cpu_state == CPUState.MPS: if cpu_state == CPUState.MPS:
torch.mps.empty_cache() torch.mps.empty_cache()
elif is_intel_xpu(): elif is_intel_xpu():
torch.xpu.synchronize()
torch.xpu.empty_cache() torch.xpu.empty_cache()
elif is_ascend_npu(): elif is_ascend_npu():
torch.npu.empty_cache() torch.npu.empty_cache()
@ -1801,7 +1810,7 @@ def debug_memory_summary():
return torch.cuda.memory.memory_summary() return torch.cuda.memory.memory_summary()
return "" return ""
class InterruptProcessingException(Exception): class InterruptProcessingException(BaseException):
pass pass
interrupt_processing_mutex = threading.RLock() interrupt_processing_mutex = threading.RLock()

View File

@ -31,6 +31,7 @@ import comfy.float
import comfy.hooks import comfy.hooks
import comfy.lora import comfy.lora
import comfy.model_management import comfy.model_management
import comfy.ops
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils import comfy.utils
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
@ -120,9 +121,20 @@ class LowVramPatch:
self.patches = patches self.patches = patches
self.convert_func = convert_func # TODO: remove self.convert_func = convert_func # TODO: remove
self.set_func = set_func self.set_func = set_func
self.prepared_patches = None
def prepare(self, allocate_buffer, stream):
self.prepared_patches = [
(patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
for patch in self.patches[self.key]
]
def clear_prepared(self):
self.prepared_patches = None
def __call__(self, weight): def __call__(self, weight):
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
@ -506,6 +518,10 @@ class ModelPatcher:
def set_model_noise_refiner_patch(self, patch): def set_model_noise_refiner_patch(self, patch):
self.set_model_patch(patch, "noise_refiner") self.set_model_patch(patch, "noise_refiner")
def set_model_middle_block_after_patch(self, patch):
self.set_model_patch(patch, "middle_block_after_patch")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {}) rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x rope_options["scale_x"] = scale_x
@ -681,9 +697,9 @@ class ModelPatcher:
sd.pop(k) sd.pop(k)
return sd return sd
def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False):
weight, set_func, convert_func = get_key_weight(self.model, key) weight, set_func, convert_func = get_key_weight(self.model, key)
if key not in self.patches: if key not in self.patches and not force_cast:
return weight return weight
inplace_update = self.weight_inplace_update or inplace_update inplace_update = self.weight_inplace_update or inplace_update
@ -691,7 +707,7 @@ class ModelPatcher:
if key not in self.backup and not return_weight: if key not in self.backup and not return_weight:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
temp_dtype = comfy.model_management.lora_compute_dtype(device_to) temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None
if device_to is not None: if device_to is not None:
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
else: else:
@ -699,9 +715,10 @@ class ModelPatcher:
if convert_func is not None: if convert_func is not None:
temp_weight = convert_func(temp_weight, inplace=True) temp_weight = convert_func(temp_weight, inplace=True)
out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight
if set_func is None: if set_func is None:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if key in self.patches:
out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key))
if return_weight: if return_weight:
return out_weight return out_weight
elif inplace_update: elif inplace_update:
@ -851,7 +868,9 @@ class ModelPatcher:
if m.comfy_patched_weights == True: if m.comfy_patched_weights == True:
continue continue
for param in params: for param, param_value in params.items():
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
key = key_param_name_to_key(n, param) key = key_param_name_to_key(n, param)
self.unpin_weight(key) self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to) self.patch_weight_to_device(key, device_to=device_to)
@ -1580,7 +1599,7 @@ class ModelPatcherDynamic(ModelPatcher):
key = key_param_name_to_key(n, param_key) key = key_param_name_to_key(n, param_key)
if key in self.backup: if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to) self.patch_weight_to_device(key, device_to=device_to, force_cast=True)
weight, _, _ = get_key_weight(self.model, key) weight, _, _ = get_key_weight(self.model, key)
if weight is not None: if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
@ -1605,6 +1624,10 @@ class ModelPatcherDynamic(ModelPatcher):
m._v = vbar.alloc(v_weight_size) m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size allocated_size += v_weight_size
for param in params:
if param not in ("weight", "bias"):
force_load_param(self, param, device_to)
else: else:
for param in params: for param in params:
key = key_param_name_to_key(n, param) key = key_param_name_to_key(n, param)

66
comfy/model_prefetch.py Normal file
View File

@ -0,0 +1,66 @@
import comfy_aimdo.model_vbar
import comfy.model_management
import comfy.ops
PREFETCH_QUEUES = []
def cleanup_prefetched_modules(comfy_modules):
for s in comfy_modules:
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
continue
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
if prefetch["signature"] is not None:
comfy_aimdo.model_vbar.vbar_unpin(s._v)
delattr(s, "_prefetch")
def cleanup_prefetch_queues():
global PREFETCH_QUEUES
for queue in PREFETCH_QUEUES:
for entry in queue:
if entry is None or not isinstance(entry, tuple):
continue
_, prefetch_state = entry
comfy_modules = prefetch_state[1]
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
PREFETCH_QUEUES = []
def prefetch_queue_pop(queue, device, module):
if queue is None:
return
consumed = queue.pop(0)
if consumed is not None:
offload_stream, prefetch_state = consumed
if offload_stream is not None:
offload_stream.wait_stream(comfy.model_management.current_stream(device))
_, comfy_modules = prefetch_state
if comfy_modules is not None:
cleanup_prefetched_modules(comfy_modules)
prefetch = queue[0]
if prefetch is not None:
comfy_modules = []
for s in prefetch.modules():
if hasattr(s, "_v"):
comfy_modules.append(s)
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))
def make_prefetch_queue(queue, device, transformer_options):
if (not transformer_options.get("prefetch_dynamic_vbars", False)
or comfy.model_management.NUM_STREAMS == 0
or comfy.model_management.is_device_cpu(device)
or not comfy.model_management.device_supports_non_blocking(device)):
return None
queue = [None] + queue + [None]
PREFETCH_QUEUES.append(queue)
return queue

View File

@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
sigma = reshape_sigma(sigma, model_output.ndim) sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class V_PREDICTION_DDPM:
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
"""
def calculate_input(self, sigma, noise):
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class EDM(V_PREDICTION): class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input): def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim) sigma = reshape_sigma(sigma, model_output.ndim)

View File

@ -79,37 +79,68 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): def materialize_meta_param(s, param_keys):
for param_key in param_keys:
param = getattr(s, param_key, None)
if param is not None and getattr(param, "is_meta", False):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None offload_stream = None
xfer_dest = None cast_buffer = None
cast_buffer_offset = 0
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
nonlocal cast_buffer
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
return
current_size = 0 if cast_buffer is None else cast_buffer.size()
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = None
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
def get_cast_buffer(buffer_size):
nonlocal offload_stream
nonlocal cast_buffer
nonlocal cast_buffer_offset
if buffer_size == 0:
return None
if offload_stream is None:
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
cast_buffer_offset += buffer_size
return buffer
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident: if resident:
weight = s._v_weight s._prefetch = prefetch
bias = s._v_bias continue
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident: materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ] xfer_source = [ s.weight, s.bias ]
@ -121,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None: if data is None:
continue continue
if data.dtype != geometry.dtype: if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None xfer_dest = None
break break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source) dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device) ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
if xfer_dest is None and offload_stream is not None:
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None:
offload_stream = comfy.model_management.get_offload_stream(device)
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
if xfer_dest is None: if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) xfer_dest = get_cast_buffer(dest_size)
offload_stream = None
if signature is None and pin is None: if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s) comfy.pinned_memory.pin_memory(s)
@ -149,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ] xfer_source = [ pin ]
#send it over #send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
comfy.model_management.sync_stream(device, offload_stream)
if cast_dest is not None: for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
ensure_offload_stream(s, cast_buffer_offset, False)
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
return offload_stream
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
prefetch = getattr(s, "_prefetch", None)
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None: if post_cast is not None:
post_cast.copy_(pre_cast) post_cast.copy_(pre_cast)
xfer_dest = cast_dest xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0] weight = params[0]
bias = params[1] bias = params[1]
if signature is not None: if prefetch["signature"] is not None:
s._v_weight = weight s._v_weight = weight
s._v_bias = bias s._v_bias = bias
s._v_signature=signature s._v_signature = prefetch["signature"]
def post_cast(s, param_key, x, dtype, resident, update_weight): def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", []) fns = getattr(s, param_key + "_function", [])
if x is None:
return None
orig = x orig = x
def to_dequant(tensor, dtype): def to_dequant(tensor, dtype):
@ -197,14 +248,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x) x = f(x)
return x return x
update_weight = signature is not None update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
if bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight) if prefetch["signature"] is not None:
if s.bias is not None: prefetch["resident"] = True
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
#FIXME: weird offload return protocol return weight, bias
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -222,10 +274,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None: if device is None:
device = input.device device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"): if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
lowvram_fn.clear_prepared()
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
if offloadable and (device != s.weight.device or if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)): (s.bias is not None and device != s.bias.device)):
@ -272,11 +360,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
if offloadable: return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream): def uncast_bias_weight(s, weight, bias, offload_stream):
@ -306,6 +390,12 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _zero_init_parameter(module, name):
param = getattr(module, name)
device = None if getattr(param, "is_meta", False) else param.device
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
@staticmethod @staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape, missing_keys, unexpected_keys, weight_shape,
@ -1151,7 +1241,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
if param is None: if param is None:
continue continue
p = fn(param) p = fn(param)
if p.is_inference(): if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone() p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():
@ -1159,6 +1249,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
return self return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -2,7 +2,6 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import psutil
from comfy.cli_args import args from comfy.cli_args import args
@ -12,11 +11,6 @@ def get_pin(module):
def pin_memory(module): def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])

View File

@ -1,6 +1,8 @@
import torch import torch
import logging import logging
from comfy.cli_args import args
try: try:
import comfy_kitchen as ck import comfy_kitchen as ck
from comfy_kitchen.tensor import ( from comfy_kitchen.tensor import (
@ -21,7 +23,15 @@ try:
ck.registry.disable("cuda") ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton") if args.enable_triton_backend:
try:
import triton
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
except ImportError as e:
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
ck.registry.disable("triton")
else:
ck.registry.disable("triton")
for k, v in ck.list_backends().items(): for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}") logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e: except ImportError as e:

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6): def rms_norm(x, weight=None, eps=1e-6):
if weight is None: if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
gligen += get_models_from_cond(conds[k], "gligen") gligen += get_models_from_cond(conds[k], "gligen")
add_models += get_models_from_cond(conds[k], "additional_models") add_models += get_models_from_cond(conds[k], "additional_models")
control_nets = set(cnets) # Order-preserving dedup. A plain set() would randomize iteration order across runs
control_nets = list(dict.fromkeys(cnets))
inference_memory = 0 inference_memory = 0
control_models = [] control_models = []

View File

@ -12,12 +12,14 @@ from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.lightricks.vae.audio_vae
import comfy.ldm.cosmos.vae import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.trellis2.vae import comfy.ldm.trellis2.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert import comfy.pixel_space_convert
@ -64,6 +66,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -478,7 +481,10 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd: elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
self.latent_channels = metadata["tae_latent_channels"]
else:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA() self.first_stage_model = StageA()
@ -661,6 +667,17 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd: elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True ddconfig["conv3d"] = True
@ -815,6 +832,24 @@ class VAE:
self.downscale_index_formula = (4, 8, 8) self.downscale_index_formula = (4, 8, 8)
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio
sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."})
self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata)
self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
self.latent_channels = self.first_stage_model.latent_channels
self.audio_sample_rate_output = self.first_stage_model.output_sample_rate
self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes
self.output_channels = 2
self.pad_channel_value = "replicate"
self.upscale_ratio = 4096
self.downscale_ratio = 4096
self.latent_dim = 2
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float32]
self.disable_offload = True
self.extra_1d_channel = 16
else: else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.") logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None self.first_stage_model = None
@ -1247,6 +1282,9 @@ class TEModel(Enum):
QWEN35_9B = 26 QWEN35_9B = 26
QWEN35_27B = 27 QWEN35_27B = 27
MINISTRAL_3_3B = 28 MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd): def detect_te_model(sd):
@ -1272,6 +1310,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd: if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1411,6 +1455,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B: elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

View File

@ -27,6 +27,7 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15 import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -1166,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_CausalAR_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"causal_ar": True,
}
sampling_settings = {
"shift": 5.0,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.unet_config.pop("causal_ar", None)
def get_model(self, state_dict, prefix="", device=None):
return model_base.WAN21_CausalAR(self, device=device)
class WAN21_I2V(WAN21_T2V): class WAN21_I2V(WAN21_T2V):
unet_config = { unet_config = {
"image_model": "wan2.1", "image_model": "wan2.1",
@ -1805,6 +1825,185 @@ class ErnieImage(supported_models_base.BASE):
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, Trellis2]
models += [SVD_img2vid] class SAM3(supported_models_base.BASE):
unet_config = {"image_model": "SAM3"}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
unet_extra_prefix = ""
def process_clip_state_dict(self, state_dict):
clip_keys = getattr(self, "_clip_stash", {})
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
def process_unet_state_dict(self, state_dict):
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
# SAM3.1: remap tracker.model.* -> tracker.*
for k in list(state_dict.keys()):
if k.startswith("tracker.model."):
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
state_dict.pop(k)
# Split fused QKV projections
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
t = state_dict.pop(k)
base, suffix = k.rsplit(".in_proj_", 1)
s = ".weight" if suffix == "weight" else ".bias"
d = t.shape[0] // 3
state_dict[base + ".q_proj" + s] = t[:d]
state_dict[base + ".k_proj" + s] = t[d:2*d]
state_dict[base + ".v_proj" + s] = t[2*d:]
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
for k in list(state_dict.keys()):
if "sam_mask_decoder.transformer." not in k:
continue
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
if new_k != k:
state_dict[new_k] = state_dict.pop(k)
return state_dict
def get_model(self, state_dict, prefix="", device=None):
return model_base.SAM3(self, device=device)
def clip_target(self, state_dict={}):
import comfy.text_encoders.sam3_clip
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
class SAM31(SAM3):
unet_config = {"image_model": "SAM31"}
class CogVideoX_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "cogvideox",
}
sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.012,
"beta_schedule": "linear",
"zsnr": True,
}
unet_extra_config = {}
latent_format = latent_formats.CogVideoX
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
class CogVideoX_I2V(CogVideoX_T2V):
unet_config = {
"image_model": "cogvideox",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out
models = [
LotusD,
Stable_Zero123,
SD15_instructpix2pix,
SD15,
SD20,
SD21UnclipL,
SD21UnclipH,
SDXL_instructpix2pix,
SDXLRefiner,
SDXL,
SSD1B,
KOALA_700M,
KOALA_1B,
Segmind_Vega,
SD_X4Upscaler,
Stable_Cascade_C,
Stable_Cascade_B,
SV3D_u,
SV3D_p,
SD3,
StableAudio,
AuraFlow,
PixArtAlpha,
PixArtSigma,
HunyuanDiT,
HunyuanDiT1,
FluxInpaint,
Flux,
LongCatImage,
FluxSchnell,
GenmoMochi,
LTXV,
LTXAV,
HunyuanVideo15_SR_Distilled,
HunyuanVideo15,
HunyuanImage21Refiner,
HunyuanImage21,
HunyuanVideoSkyreelsI2V,
HunyuanVideoI2V,
HunyuanVideo,
CosmosT2V,
CosmosI2V,
CosmosT2IPredict2,
CosmosI2VPredict2,
ZImagePixelSpace,
ZImage,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
WAN21_T2V,
WAN21_I2V,
WAN21_FunControl2V,
WAN21_Vace,
WAN21_Camera,
WAN22_Camera,
WAN22_S2V,
WAN21_HuMo,
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,
HiDream,
Chroma,
ChromaRadiance,
ACEStep,
ACEStep15,
Omnigen2,
QwenImage,
Flux2,
Kandinsky5Image,
Kandinsky5,
Anima,
RT_DETR_v4,
ErnieImage,
SAM3,
SAM31,
CogVideoX_I2V,
CogVideoX_T2V,
SVD_img2vid,
Trellis2
]

View File

@ -7,6 +7,7 @@ from tqdm.auto import tqdm
from collections import namedtuple, deque from collections import namedtuple, deque
import comfy.ops import comfy.ops
import comfy.model_management
operations=comfy.ops.disable_weight_init operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
@ -47,11 +48,14 @@ class TGrow(nn.Module):
x = self.conv(x) x = self.conv(x)
return x.reshape(-1, C, H, W) return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar): def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
patch_size=1, decode=False):
B, T, C, H, W = x.shape B, T, C, H, W = x.shape
if parallel: if parallel:
x = x.reshape(B*T, C, H, W) x = x.reshape(B*T, C, H, W)
if not decode and patch_size > 1:
x = F.pixel_unshuffle(x, patch_size)
# parallel over input timesteps, iterate over blocks # parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar): for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock): if isinstance(b, MemBlock):
@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
x = b(x, mem) x = b(x, mem)
else: else:
x = b(x) x = b(x)
BT, C, H, W = x.shape if decode and patch_size > 1:
T = BT // B x = F.pixel_shuffle(x, patch_size)
x = x.view(B, T, C, H, W) x = x.view(B, x.shape[0] // B, *x.shape[1:])
x = x.to(output_device)
else: else:
out = [] out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
progress_bar = tqdm(range(T), disable=not show_progress_bar) progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model) mem = [None] * len(model)
while work_queue: while work_queue:
xt, i = work_queue.popleft() xt, i = work_queue.popleft()
if i == 0: if i == 0:
progress_bar.update(1) progress_bar.update(1)
if not decode and patch_size > 1:
xt = F.pixel_unshuffle(xt, patch_size)
if i == len(model): if i == len(model):
out.append(xt) if decode and patch_size > 1:
xt = F.pixel_shuffle(xt, patch_size)
out.append(xt.to(output_device))
del xt del xt
else: else:
b = model[i] b = model[i]
@ -165,24 +176,20 @@ class TAEHV(nn.Module):
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if self.patch_size > 1:
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, self.patch_size)
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0: if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale # pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale n_pad = self.t_downscale - x.shape[1] % self.t_downscale
padding = x[:, -1:].repeat_interleave(n_pad, dim=1) padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1) x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
patch_size=self.patch_size).movedim(2, 1)
return self.process_out(x) return self.process_out(x)
def decode(self, x, **kwargs): def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
if self.patch_size > 1: output_device=comfy.model_management.intermediate_device(),
x = F.pixel_shuffle(x, self.patch_size) patch_size=self.patch_size, decode=True)
return x[:, self.frames_to_trim:].movedim(2, 1) return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -17,32 +17,79 @@ class Clamp(nn.Module):
return torch.tanh(x / 3) * 3 return torch.tanh(x / 3) * 3
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_in, n_out): def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
super().__init__() super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x): if not use_midblock_gn:
self.pool = None
return
n_gn = n_in * 4
self.pool = nn.Sequential(
comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
nn.ReLU(inplace=True),
comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.pool is not None:
x = x + self.pool(x)
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))
def Encoder(latent_channels=4): class Encoder(nn.Sequential):
return nn.Sequential( def __init__(self, latent_channels: int = 4, use_gn: bool = False):
conv(3, 64), Block(64, 64), super().__init__(
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, latent_channels), conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
) conv(64, latent_channels),
)
class Decoder(nn.Sequential):
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
super().__init__(
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class DecoderFlux2(Decoder):
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
if latent_channels != 128 or not use_gn:
raise ValueError("Unexpected parameters for Flux2 TAE module")
super().__init__(latent_channels=32, use_gn=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
x = (
x
.reshape(B, 32, 2, 2, H, W)
.permute(0, 1, 4, 2, 5, 3)
.reshape(B, 32, H * 2, W * 2)
)
return super().forward(x)
class EncoderFlux2(Encoder):
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
if latent_channels != 128 or not use_gn:
raise ValueError("Unexpected parameters for Flux2 TAE module")
super().__init__(latent_channels=32, use_gn=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
B, C, H, W = result.shape
return (
result
.reshape(B, C, H // 2, 2, W // 2, 2)
.permute(0, 1, 3, 5, 2, 4)
.reshape(B, 128, H // 2, W // 2)
)
def Decoder(latent_channels=4):
return nn.Sequential(
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module): class TAESD(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
@ -51,8 +98,15 @@ class TAESD(nn.Module):
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4): def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
"""Initialize pretrained TAESD on the given device from the given checkpoints.""" """Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__() super().__init__()
self.taesd_encoder = Encoder(latent_channels=latent_channels) if latent_channels == 128:
self.taesd_decoder = Decoder(latent_channels=latent_channels) encoder_class = EncoderFlux2
decoder_class = DecoderFlux2
else:
encoder_class = Encoder
decoder_class = Decoder
self.taesd_encoder = encoder_class(latent_channels=latent_channels)
self.taesd_decoder = decoder_class(latent_channels=latent_channels)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None: if encoder_path is not None:
@ -61,19 +115,19 @@ class TAESD(nn.Module):
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod @staticmethod
def scale_latents(x): def scale_latents(x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]""" """raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
@staticmethod @staticmethod
def unscale_latents(x): def unscale_latents(x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents""" """[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x): def decode(self, x: torch.Tensor) -> torch.Tensor:
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale) x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2) x_sample = x_sample.sub(0.5).mul(2)
return x_sample return x_sample
def encode(self, x): def encode(self, x: torch.Tensor) -> torch.Tensor:
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift

View File

@ -0,0 +1,6 @@
import comfy.text_encoders.sd3_clip
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)

View File

@ -3,7 +3,7 @@ from comfy import sd1_clip
import comfy.text_encoders.llama import comfy.text_encoders.llama
class Ministral3_3BTokenizer(Mistral3Tokenizer): class Ministral3_3BTokenizer(Mistral3Tokenizer):
def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_data={}): def __init__(self, embedding_directory=None, embedding_size=5120, embedding_key='ministral3_3b', tokenizer_data={}):
return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data) return super().__init__(embedding_directory=embedding_directory, embedding_size=embedding_size, embedding_key=embedding_key, tokenizer_data=tokenizer_data)
class ErnieTokenizer(sd1_clip.SD1Tokenizer): class ErnieTokenizer(sd1_clip.SD1Tokenizer):
@ -35,4 +35,4 @@ def te(dtype_llama=None, llama_quantization_metadata=None):
model_options = model_options.copy() model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options) super().__init__(device=device, dtype=dtype, model_options=model_options)
return ErnieTEModel return ErnieTEModel_

File diff suppressed because it is too large Load Diff

View File

@ -82,6 +82,7 @@ class Ministral3_3BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
stop_tokens = [2]
@dataclass @dataclass
class Qwen25_3BConfig: class Qwen25_3BConfig:
@ -520,7 +521,7 @@ class Attention(nn.Module):
else: else:
present_key_value = (xk, xv, index + num_tokens) present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window: if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:] xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -532,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value return self.o_proj(output), present_key_value
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__() super().__init__()
ops = ops or nn intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu": if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh": elif config.mlp_activation == "gelu_pytorch_tanh":
@ -646,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module): class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None): def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2 transformer = TransformerBlockGemma2
self.normalize_in = True self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else: else:
transformer = TransformerBlock transformer = TransformerBlock
self.normalize_in = False self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops) transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -689,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims, self.config.rope_dims,
device=device) device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None: if embeds is not None:
x = embeds x = embeds
else: else:
x = self.embed_tokens(x, out_dtype=dtype) x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1] seq_len = x.shape[1]
past_len = 0 past_len = 0
if past_key_values is not None and len(past_key_values) > 0: if past_key_values is not None and len(past_key_values) > 0:
@ -849,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device device = embeds.device
if stop_tokens is None: if stop_tokens is None:
@ -874,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length) pbar = comfy.utils.ProgressBar(max_length)
# Generation loop # Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"): for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1] logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item() token_id = next_token[0].item()
generated_token_ids.append(token_id) generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype) embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1) pbar.update(1)
if token_id in stop_tokens: if token_id in stop_tokens:
@ -969,7 +970,7 @@ class Mistral3Small24B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Ministral3_3B(BaseLlama, torch.nn.Module): class Ministral3_3B(BaseLlama, BaseQwen3, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Ministral3_3BConfig(**config_dict) config = Ministral3_3BConfig(**config_dict)

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens] tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn> return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module): class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device): def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device) embeds, _, _, _ = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds return embeds
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -0,0 +1,97 @@
import re
from comfy import sd1_clip
SAM3_CLIP_CONFIG = {
"architectures": ["CLIPTextModel"],
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"max_position_embeddings": 32,
"projection_dim": 512,
"vocab_size": 49408,
"layer_norm_eps": 1e-5,
"eos_token_id": 49407,
}
class SAM3ClipModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
class SAM3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
self.disable_weights = True
def _parse_prompts(text):
"""Split comma-separated prompts with optional :N max detections per category"""
text = text.replace("(", "").replace(")", "")
parts = [p.strip() for p in text.split(",") if p.strip()]
result = []
for part in parts:
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
if m:
text_part = m.group(1).strip()
val = m.group(2)
max_det = max(1, round(float(val)))
result.append((text_part, max_det))
else:
result.append((part, 1))
return result
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
parsed = _parse_prompts(text)
if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
# Tokenize each prompt part separately, store per-part batches and metadata
inner = getattr(self, self.clip)
per_prompt = []
for prompt_text, max_det in parsed:
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
per_prompt.append((batches, max_det))
# Main output uses first prompt's tokens (for compatibility)
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
return out
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
def encode_token_weights(self, token_weight_pairs):
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
if per_prompt is None:
return super().encode_token_weights(token_weight_pairs)
# Encode each prompt separately, pack into extra dict
inner = getattr(self, self.clip)
multi_cond = []
first_pooled = None
for batches, max_det in per_prompt:
out = inner.encode_token_weights(batches)
cond, pooled = out[0], out[1]
extra = out[2] if len(out) > 2 else {}
if first_pooled is None:
first_pooled = pooled
multi_cond.append({
"cond": cond,
"attention_mask": extra.get("attention_mask"),
"max_detections": max_det,
})
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
main = multi_cond[0]
main_extra = {}
if main["attention_mask"] is not None:
main_extra["attention_mask"] = main["attention_mask"]
main_extra["sam3_multi_cond"] = multi_cond
return (main["cond"], first_pooled, main_extra)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend,
allowing graceful protocol evolution while maintaining backward compatibility. allowing graceful protocol evolution while maintaining backward compatibility.
""" """
from typing import Any import logging
from typing import Any, TypedDict
from comfy.cli_args import args from comfy.cli_args import args
class FeatureFlagInfo(TypedDict):
type: str
default: Any
description: str
# Registry of known CLI-settable feature flags.
# Launchers can query this via --list-feature-flags to discover valid flags.
CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
"show_signin_button": {
"type": "bool",
"default": False,
"description": "Show the sign-in button in the frontend even when not signed in",
},
}
def _coerce_bool(v: str) -> bool:
"""Strict bool coercion: only 'true'/'false' (case-insensitive).
Anything else raises ValueError so the caller can warn and drop the flag,
rather than silently treating typos like 'ture' or 'yes' as False.
"""
lower = v.lower()
if lower == "true":
return True
if lower == "false":
return False
raise ValueError(f"expected 'true' or 'false', got {v!r}")
_COERCE_FNS: dict[str, Any] = {
"bool": _coerce_bool,
"int": lambda v: int(v),
"float": lambda v: float(v),
}
def _coerce_flag_value(key: str, raw_value: str) -> Any:
"""Coerce a raw string value using the registry type, or keep as string.
Returns the raw string if the key is unregistered or the type is unknown.
Raises ValueError/TypeError if the key is registered with a known type but
the value cannot be coerced; callers are expected to warn and drop the flag.
"""
info = CLI_FEATURE_FLAG_REGISTRY.get(key)
if info is None:
return raw_value
coerce = _COERCE_FNS.get(info["type"])
if coerce is None:
return raw_value
return coerce(raw_value)
def _parse_cli_feature_flags() -> dict[str, Any]:
"""Parse --feature-flag key=value pairs from CLI args into a dict.
Items without '=' default to the value 'true' (bare flag form).
Flags whose value cannot be coerced to the registered type are dropped
with a warning, so a typo like '--feature-flag some_bool=ture' does not
silently take effect as the wrong value.
"""
result: dict[str, Any] = {}
for item in getattr(args, "feature_flag", []):
key, sep, raw_value = item.partition("=")
key = key.strip()
if not key:
continue
if not sep:
raw_value = "true"
try:
result[key] = _coerce_flag_value(key, raw_value.strip())
except (ValueError, TypeError) as e:
info = CLI_FEATURE_FLAG_REGISTRY.get(key, {})
logging.warning(
"Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.",
key, raw_value.strip(), info.get("type", "?"), e,
)
return result
# Default server capabilities # Default server capabilities
SERVER_FEATURE_FLAGS: dict[str, Any] = { _CORE_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"assets": args.enable_assets, "assets": args.enable_assets,
} }
# CLI-provided flags cannot overwrite core flags
_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS}
SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags}
def get_connection_feature( def get_connection_feature(
sockets_metadata: dict[str, dict[str, Any]], sockets_metadata: dict[str, dict[str, Any]],

View File

@ -9,6 +9,7 @@ from comfy_api.latest._input import (
CurveInput, CurveInput,
MonotoneCubicCurve, MonotoneCubicCurve,
LinearCurve, LinearCurve,
RangeInput,
) )
__all__ = [ __all__ = [
@ -21,4 +22,5 @@ __all__ = [
"CurveInput", "CurveInput",
"MonotoneCubicCurve", "MonotoneCubicCurve",
"LinearCurve", "LinearCurve",
"RangeInput",
] ]

View File

@ -1,5 +1,6 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .range_types import RangeInput
from .video_types import VideoInput from .video_types import VideoInput
__all__ = [ __all__ = [
@ -12,4 +13,5 @@ __all__ = [
"CurveInput", "CurveInput",
"MonotoneCubicCurve", "MonotoneCubicCurve",
"LinearCurve", "LinearCurve",
"RangeInput",
] ]

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import logging
import math
import numpy as np
logger = logging.getLogger(__name__)
class RangeInput:
"""Represents a levels/range adjustment: input range [min, max] with
optional midpoint (gamma control).
Generates a 1D LUT identical to GIMP's levels mapping:
1. Normalize input to [0, 1] using [min, max]
2. Apply gamma correction: pow(value, 1/gamma)
3. Clamp to [0, 1]
The midpoint field is a position in [0, 1] representing where the
midtone falls within [min, max]. It maps to gamma via:
gamma = -log2(midpoint)
So midpoint=0.5 gamma=1.0 (linear).
"""
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
self.min_val = min_val
self.max_val = max_val
self.midpoint = midpoint
@staticmethod
def from_raw(data) -> RangeInput:
if isinstance(data, RangeInput):
return data
if isinstance(data, dict):
return RangeInput(
min_val=float(data.get("min", 0.0)),
max_val=float(data.get("max", 1.0)),
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
)
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table mapping [0, 1] input through this
levels adjustment.
The LUT maps normalized input values (0..1) to output values (0..1),
matching the GIMP levels formula.
"""
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
in_range = self.max_val - self.min_val
if abs(in_range) < 1e-10:
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
# Normalize: map [min, max] → [0, 1]
result = (xs - self.min_val) / in_range
result = np.clip(result, 0.0, 1.0)
# Gamma correction from midpoint
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
gamma = max(-math.log2(self.midpoint), 0.001)
inv_gamma = 1.0 / gamma
mask = result > 0
result[mask] = np.power(result[mask], inv_gamma)
return result
def __repr__(self) -> str:
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"

View File

@ -12,6 +12,7 @@ import numpy as np
import math import math
import torch import torch
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
import logging
def container_to_output_format(container_format: str | None) -> str | None: def container_to_output_format(container_format: str | None) -> str | None:
@ -238,64 +239,125 @@ class VideoFromFile(VideoInput):
start_time = max(self._get_raw_duration() + self.__start_time, 0) start_time = max(self._get_raw_duration() + self.__start_time, 0)
else: else:
start_time = self.__start_time start_time = self.__start_time
# Get video frames # Get video frames
frames = [] frames = []
audio_frames = []
alphas = None
start_pts = int(start_time / video_stream.time_base) start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) if start_pts != 0:
container.seek(start_pts, stream=video_stream)
image_format = 'gbrpf32le'
process_image_format = lambda a: a
audio = None
streams = [video_stream]
has_first_audio_frame = False
checked_alpha = False
# Default to False so we decode until EOF if duration is 0
video_done = False
audio_done = True
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
streams += [audio_stream]
resampler = av.audio.resampler.AudioResampler(format='fltp')
audio_done = False
for packet in container.demux(*streams):
if video_done and audio_done:
break
if packet.stream.type == "video":
if video_done:
continue
try:
for frame in packet.decode():
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
video_done = True
break
if not checked_alpha:
alpha_channel = False
for comp in frame.format.components:
if comp.is_alpha or frame.format.name == "pal8":
alphas = []
alpha_channel = True
break
if frame.format.name in ("yuvj420p", "yuvj422p", "yuvj444p", "rgb24", "rgba", "pal8"):
process_image_format = lambda a: a.float() / 255.0
if alpha_channel:
image_format = 'rgba'
else:
image_format = 'rgb24'
else:
process_image_format = lambda a: a
if alpha_channel:
image_format = 'gbrapf32le'
else:
image_format = 'gbrpf32le'
checked_alpha = True
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
if frame.rotation != 0:
k = int(round(frame.rotation // 90))
img = np.rot90(img, k=k, axes=(0, 1)).copy()
if alphas is None:
frames.append(torch.from_numpy(img))
else:
frames.append(torch.from_numpy(img[..., :-1]))
alphas.append(torch.from_numpy(img[..., -1:]))
except av.error.InvalidDataError:
logging.info("pyav decode error")
elif packet.stream.type == "audio":
if audio_done:
continue
aframes = itertools.chain.from_iterable(
map(resampler.resample, packet.decode())
)
for frame in aframes:
if self.__duration and frame.time > start_time + self.__duration:
audio_done = True
break
if not has_first_audio_frame:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples:
has_first_audio_frame = True
audio_frames.append(frame.to_ndarray()[..., to_skip:])
else:
audio_frames.append(frame.to_ndarray())
images = process_image_format(torch.stack(frames)) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
if alphas is not None:
alphas = process_image_format(torch.stack(alphas)) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
# Get frame rate # Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available if len(audio_frames) > 0:
audio = None audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
container.seek(start_pts, stream=video_stream) if self.__duration:
# Use last stream for consistency audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
for frame in frames: audio = AudioInput({
offset_seconds = start_time - frame.pts * audio_stream.time_base "waveform": audio_tensor,
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
if to_skip < frame.samples: })
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if self.__duration and frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
metadata = container.metadata metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):

View File

@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO):
Type = list[int] Type = list[int]
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
from comfy_api.input import RangeInput
if TYPE_CHECKING:
Type = RangeInput
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None,
display: str=None,
gradient_stops: list=None,
show_midpoint: bool=None,
midpoint_scale: str=None,
value_min: float=None,
value_max: float=None,
advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {"min": 0.0, "max": 1.0}
self.display = display
self.gradient_stops = gradient_stops
self.show_midpoint = show_midpoint
self.midpoint_scale = midpoint_scale
self.value_min = value_min
self.value_max = value_max
def as_dict(self):
return super().as_dict() | prune_dict({
"display": self.display,
"gradient_stops": self.gradient_stops,
"show_midpoint": self.show_midpoint,
"midpoint_scale": self.midpoint_scale,
"value_min": self.value_min,
"value_max": self.value_max,
})
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2276,5 +2313,6 @@ __all__ = [
"BoundingBox", "BoundingBox",
"Curve", "Curve",
"Histogram", "Histogram",
"Range",
"NodeReplace", "NodeReplace",
] ]

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from .._input import ImageInput, AudioInput from .._input import ImageInput, AudioInput, MaskInput
class VideoCodec(str, Enum): class VideoCodec(str, Enum):
AUTO = "auto" AUTO = "auto"
@ -48,5 +48,4 @@ class VideoComponents:
frame_rate: Fraction frame_rate: Fraction
audio: Optional[AudioInput] = None audio: Optional[AudioInput] = None
metadata: Optional[dict] = None metadata: Optional[dict] = None
alpha: Optional[MaskInput] = None

View File

@ -52,6 +52,26 @@ class TaskImageContent(BaseModel):
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
class TaskVideoContentUrl(BaseModel):
url: str = Field(...)
class TaskVideoContent(BaseModel):
type: str = Field("video_url")
video_url: TaskVideoContentUrl = Field(...)
role: str = Field("reference_video")
class TaskAudioContentUrl(BaseModel):
url: str = Field(...)
class TaskAudioContent(BaseModel):
type: str = Field("audio_url")
audio_url: TaskAudioContentUrl = Field(...)
role: str = Field("reference_audio")
class Text2VideoTaskCreationRequest(BaseModel): class Text2VideoTaskCreationRequest(BaseModel):
model: str = Field(...) model: str = Field(...)
content: list[TaskTextContent] = Field(..., min_length=1) content: list[TaskTextContent] = Field(..., min_length=1)
@ -64,6 +84,17 @@ class Image2VideoTaskCreationRequest(BaseModel):
generate_audio: bool | None = Field(...) generate_audio: bool | None = Field(...)
class Seedance2TaskCreationRequest(BaseModel):
model: str = Field(...)
content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = Field(..., min_length=1)
generate_audio: bool | None = Field(None)
resolution: str | None = Field(None)
ratio: str | None = Field(None)
duration: int | None = Field(None, ge=4, le=15)
seed: int | None = Field(None, ge=0, le=2147483647)
watermark: bool | None = Field(None)
class TaskCreationResponse(BaseModel): class TaskCreationResponse(BaseModel):
id: str = Field(...) id: str = Field(...)
@ -77,12 +108,67 @@ class TaskStatusResult(BaseModel):
video_url: str = Field(...) video_url: str = Field(...)
class TaskStatusUsage(BaseModel):
completion_tokens: int = Field(0)
total_tokens: int = Field(0)
class TaskStatusResponse(BaseModel): class TaskStatusResponse(BaseModel):
id: str = Field(...) id: str = Field(...)
model: str = Field(...) model: str = Field(...)
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
error: TaskStatusError | None = Field(None) error: TaskStatusError | None = Field(None)
content: TaskStatusResult | None = Field(None) content: TaskStatusResult | None = Field(None)
usage: TaskStatusUsage | None = Field(None)
class GetAssetResponse(BaseModel):
id: str = Field(...)
name: str | None = Field(None)
url: str | None = Field(None)
asset_type: str = Field(...)
group_id: str = Field(...)
status: str = Field(...)
error: TaskStatusError | None = Field(None)
class SeedanceCreateVisualValidateSessionResponse(BaseModel):
session_id: str = Field(...)
h5_link: str = Field(...)
class SeedanceGetVisualValidateSessionResponse(BaseModel):
session_id: str = Field(...)
status: str = Field(...)
group_id: str | None = Field(None)
error_code: str | None = Field(None)
error_message: str | None = Field(None)
class SeedanceCreateAssetRequest(BaseModel):
group_id: str = Field(...)
url: str = Field(...)
asset_type: str = Field(...)
name: str | None = Field(None, max_length=64)
project_name: str | None = Field(None)
class SeedanceCreateAssetResponse(BaseModel):
asset_id: str = Field(...)
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
SEEDANCE2_PRICE_PER_1K_TOKENS = {
("dreamina-seedance-2-0-260128", False): 0.007,
("dreamina-seedance-2-0-260128", True): 0.0043,
("dreamina-seedance-2-0-fast-260128", False): 0.0056,
("dreamina-seedance-2-0-fast-260128", True): 0.0033,
}
RECOMMENDED_PRESETS = [ RECOMMENDED_PRESETS = [
@ -112,6 +198,19 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [
("Custom", None, None), ("Custom", None, None),
] ]
# Seedance 2.0 reference video pixel count limits per model and output resolution.
SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = {
"dreamina-seedance-2-0-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
"1080p": {"min": 409_600, "max": 2_073_600},
},
"dreamina-seedance-2-0-fast-260128": {
"480p": {"min": 409_600, "max": 927_408},
"720p": {"min": 409_600, "max": 927_408},
},
}
# The time in this dictionary are given for 10 seconds duration. # The time in this dictionary are given for 10 seconds duration.
VIDEO_TASKS_EXECUTION_TIME = { VIDEO_TASKS_EXECUTION_TIME = {
"seedance-1-0-lite-t2v-250428": { "seedance-1-0-lite-t2v-250428": {

View File

@ -1,15 +1,12 @@
from __future__ import annotations from __future__ import annotations
import torch
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Optional, Union
import torch
from pydantic import BaseModel, Field, confloat from pydantic import BaseModel, Field, confloat
class LumaIO: class LumaIO:
LUMA_REF = "LUMA_REF" LUMA_REF = "LUMA_REF"
LUMA_CONCEPTS = "LUMA_CONCEPTS" LUMA_CONCEPTS = "LUMA_CONCEPTS"
@ -183,13 +180,13 @@ class LumaAssets(BaseModel):
class LumaImageRef(BaseModel): class LumaImageRef(BaseModel):
'''Used for image gen''' """Used for image gen"""
url: str = Field(..., description='The URL of the image reference') url: str = Field(..., description='The URL of the image reference')
weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference') weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
class LumaImageReference(BaseModel): class LumaImageReference(BaseModel):
'''Used for video gen''' """Used for video gen"""
type: Optional[str] = Field('image', description='Input type, defaults to image') type: Optional[str] = Field('image', description='Input type, defaults to image')
url: str = Field(..., description='The URL of the image') url: str = Field(..., description='The URL of the image')
@ -251,3 +248,32 @@ class LumaGeneration(BaseModel):
assets: Optional[LumaAssets] = Field(None, description='The assets of the generation') assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
model: str = Field(..., description='The model used for the generation') model: str = Field(..., description='The model used for the generation')
request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation") request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
class Luma2ImageRef(BaseModel):
url: str | None = None
data: str | None = None
media_type: str | None = None
class Luma2GenerationRequest(BaseModel):
prompt: str = Field(..., min_length=1, max_length=6000)
model: str | None = None
type: str | None = None
aspect_ratio: str | None = None
style: str | None = None
output_format: str | None = None
web_search: bool | None = None
image_ref: list[Luma2ImageRef] | None = None
source: Luma2ImageRef | None = None
class Luma2Generation(BaseModel):
id: str | None = None
type: str | None = None
state: str | None = None
model: str | None = None
created_at: str | None = None
output: list[LumaImageReference] | None = None
failure_reason: str | None = None
failure_code: str | None = None

View File

@ -1,152 +0,0 @@
from enum import Enum
from typing import Optional, Dict, Any
from pydantic import BaseModel, Field, StrictBytes
class MoonvalleyPromptResponse(BaseModel):
error: Optional[Dict[str, Any]] = None
frame_conditioning: Optional[Dict[str, Any]] = None
id: Optional[str] = None
inference_params: Optional[Dict[str, Any]] = None
meta: Optional[Dict[str, Any]] = None
model_params: Optional[Dict[str, Any]] = None
output_url: Optional[str] = None
prompt_text: Optional[str] = None
status: Optional[str] = None
class MoonvalleyTextToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
75, description='Number of cooldown steps (calculated based on num_frames)'
)
fps: Optional[int] = Field(
24, description='Frames per second of the generated video'
)
guidance_scale: Optional[float] = Field(
10, description='Guidance scale for generation control'
)
height: Optional[int] = Field(
1080, description='Height of the generated video in pixels'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
num_frames: Optional[int] = Field(64, description='Number of frames to generate')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
0, description='Number of warmup steps (calculated based on num_frames)'
)
width: Optional[int] = Field(
1920, description='Width of the generated video in pixels'
)
class MoonvalleyTextToVideoRequest(BaseModel):
image_url: Optional[str] = None
inference_params: Optional[MoonvalleyTextToVideoInferenceParams] = None
prompt_text: Optional[str] = None
webhook_url: Optional[str] = None
class MoonvalleyUploadFileRequest(BaseModel):
file: Optional[StrictBytes] = None
class MoonvalleyUploadFileResponse(BaseModel):
access_url: Optional[str] = None
class MoonvalleyVideoToVideoInferenceParams(BaseModel):
add_quality_guidance: Optional[bool] = Field(
True, description='Whether to add quality guidance'
)
caching_coefficient: Optional[float] = Field(
0.3, description='Caching coefficient for optimization'
)
caching_cooldown: Optional[int] = Field(
3, description='Number of caching cooldown steps'
)
caching_warmup: Optional[int] = Field(
3, description='Number of caching warmup steps'
)
clip_value: Optional[float] = Field(
3, description='CLIP value for generation control'
)
conditioning_frame_index: Optional[int] = Field(
0, description='Index of the conditioning frame'
)
cooldown_steps: Optional[int] = Field(
36, description='Number of cooldown steps (calculated based on num_frames)'
)
guidance_scale: Optional[float] = Field(
15, description='Guidance scale for generation control'
)
negative_prompt: Optional[str] = Field(None, description='Negative prompt text')
seed: Optional[int] = Field(
None, description='Random seed for generation (default: random)'
)
shift_value: Optional[float] = Field(
3, description='Shift value for generation control'
)
steps: Optional[int] = Field(80, description='Number of denoising steps')
use_guidance_schedule: Optional[bool] = Field(
True, description='Whether to use guidance scheduling'
)
use_negative_prompts: Optional[bool] = Field(
False, description='Whether to use negative prompts'
)
use_timestep_transform: Optional[bool] = Field(
True, description='Whether to use timestep transformation'
)
warmup_steps: Optional[int] = Field(
24, description='Number of warmup steps (calculated based on num_frames)'
)
class ControlType(str, Enum):
motion_control = 'motion_control'
pose_control = 'pose_control'
class MoonvalleyVideoToVideoRequest(BaseModel):
control_type: ControlType = Field(
..., description='Supported types for video control'
)
inference_params: Optional[MoonvalleyVideoToVideoInferenceParams] = None
prompt_text: str = Field(..., description='Describes the video to generate')
video_url: str = Field(..., description='Url to control video')
webhook_url: Optional[str] = Field(
None, description='Optional webhook URL for notifications'
)

View File

@ -56,14 +56,14 @@ class ModelResponseProperties(BaseModel):
instructions: str | None = Field(None) instructions: str | None = Field(None)
max_output_tokens: int | None = Field(None) max_output_tokens: int | None = Field(None)
model: str | None = Field(None) model: str | None = Field(None)
temperature: float | None = Field(1, description="Controls randomness in the response", ge=0.0, le=2.0) temperature: float | None = Field(None, description="Controls randomness in the response", ge=0.0, le=2.0)
top_p: float | None = Field( top_p: float | None = Field(
1, None,
description="Controls diversity of the response via nucleus sampling", description="Controls diversity of the response via nucleus sampling",
ge=0.0, ge=0.0,
le=1.0, le=1.0,
) )
truncation: str | None = Field("disabled", description="Allowed values: 'auto' or 'disabled'") truncation: str | None = Field(None, description="Allowed values: 'auto' or 'disabled'")
class ResponseProperties(BaseModel): class ResponseProperties(BaseModel):

View File

@ -1,4 +1,4 @@
from typing import Optional, Union from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing") grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain") grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel): class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel): class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...) source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...) output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -118,7 +118,7 @@ class Wan27ReferenceVideoInputField(BaseModel):
class Wan27ReferenceVideoParametersField(BaseModel): class Wan27ReferenceVideoParametersField(BaseModel):
resolution: str = Field(...) resolution: str = Field(...)
ratio: str | None = Field(None) ratio: str | None = Field(None)
duration: int = Field(5, ge=2, le=10) duration: int = Field(5, ge=2, le=15)
watermark: bool = Field(False) watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
@ -157,7 +157,7 @@ class Wan27VideoEditInputField(BaseModel):
class Wan27VideoEditParametersField(BaseModel): class Wan27VideoEditParametersField(BaseModel):
resolution: str = Field(...) resolution: str = Field(...)
ratio: str | None = Field(None) ratio: str | None = Field(None)
duration: int = Field(0) duration: int | None = Field(0)
audio_setting: str = Field("auto") audio_setting: str = Field("auto")
watermark: bool = Field(False) watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)

File diff suppressed because it is too large Load Diff

View File

@ -178,7 +178,6 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
status_extractor=lambda x: x.data.status, status_extractor=lambda x: x.data.status,
price_extractor=lambda x: request_price, price_extractor=lambda x: request_price,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url)) return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.res_url))
@ -324,7 +323,6 @@ class HitPawVideoEnhance(IO.ComfyNode):
status_extractor=lambda x: x.data.status, status_extractor=lambda x: x.data.status,
price_extractor=lambda x: request_price, price_extractor=lambda x: request_price,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=320,
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.res_url))

View File

@ -221,14 +221,17 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
obj_result = None
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
obj_result.obj, obj_result.obj if obj_result else None,
obj_result.texture, obj_result.texture if obj_result else None,
) )
@ -378,17 +381,30 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url) obj_file_response = get_file_from_response(result.ResultFile3Ds, "obj", raise_if_not_found=False)
if obj_file_response:
obj_result = await download_and_extract_obj_zip(obj_file_response.Url)
return IO.NodeOutput(
f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
obj_result.obj,
obj_result.texture,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
obj_result.obj, None,
obj_result.texture, None,
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3), None,
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3), None,
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3), None,
) )

View File

@ -862,7 +862,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -904,12 +904,13 @@ class OmniProTextToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -934,6 +935,8 @@ class OmniProTextToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -963,6 +966,12 @@ class OmniProTextToVideoNode(IO.ComfyNode):
f"must equal the global duration ({duration}s)." f"must equal the global duration ({duration}s)."
) )
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -972,7 +981,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
mode="pro" if resolution == "1080p" else "std", mode=mode,
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
shot_type="customize" if multi_shot else None, shot_type="customize" if multi_shot else None,
@ -1014,7 +1023,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True, optional=True,
tooltip="Up to 6 additional reference images.", tooltip="Up to 6 additional reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1061,12 +1070,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1093,6 +1103,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1161,6 +1173,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1170,7 +1188,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -1204,7 +1222,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images", "reference_images",
tooltip="Up to 7 reference images.", tooltip="Up to 7 reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1251,12 +1269,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1282,6 +1301,8 @@ class OmniProImageToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1320,6 +1341,12 @@ class OmniProImageToVideoNode(IO.ComfyNode):
image_list: list[OmniParamImage] = [] image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1330,7 +1357,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -2860,7 +2887,7 @@ class KlingVideoNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=["16:9", "9:16", "1:1"], options=["16:9", "9:16", "1:1"],
@ -2913,7 +2940,11 @@ class KlingVideoNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -2943,7 +2974,12 @@ class KlingVideoNode(IO.ComfyNode):
start_frame: Input.Image | None = None, start_frame: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
_ = seed _ = seed
mode = "pro" if model["resolution"] == "1080p" else "std" if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
custom_multi_shot = False custom_multi_shot = False
if multi_shot["multi_shot"] == "disabled": if multi_shot["multi_shot"] == "disabled":
shot_type = None shot_type = None
@ -3057,7 +3093,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
], ],
), ),
], ],
@ -3089,7 +3125,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -3118,6 +3158,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
@ -3127,7 +3173,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
image=image_url, image=image_url,
image_tail=image_tail_url, image_tail=image_tail_url,
prompt=prompt, prompt=prompt,
mode="pro" if model["resolution"] == "1080p" else "std", mode=mode,
duration=str(duration), duration=str(duration),
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
), ),

View File

@ -1,10 +1,11 @@
from typing import Optional
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.luma import ( from comfy_api_nodes.apis.luma import (
Luma2Generation,
Luma2GenerationRequest,
Luma2ImageRef,
LumaAspectRatio, LumaAspectRatio,
LumaCharacterRef, LumaCharacterRef,
LumaConceptChain, LumaConceptChain,
@ -30,6 +31,7 @@ from comfy_api_nodes.util import (
download_url_to_video_output, download_url_to_video_output,
poll_op, poll_op,
sync_op, sync_op,
upload_image_to_comfyapi,
upload_images_to_comfyapi, upload_images_to_comfyapi,
validate_string, validate_string,
) )
@ -212,9 +214,9 @@ class LumaImageGenerationNode(IO.ComfyNode):
aspect_ratio: str, aspect_ratio: str,
seed, seed,
style_image_weight: float, style_image_weight: float,
image_luma_ref: Optional[LumaReferenceChain] = None, image_luma_ref: LumaReferenceChain | None = None,
style_image: Optional[torch.Tensor] = None, style_image: torch.Tensor | None = None,
character_image: Optional[torch.Tensor] = None, character_image: torch.Tensor | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=3) validate_string(prompt, strip_whitespace=True, min_length=3)
# handle image_luma_ref # handle image_luma_ref
@ -434,7 +436,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
duration: str, duration: str,
loop: bool, loop: bool,
seed, seed,
luma_concepts: Optional[LumaConceptChain] = None, luma_concepts: LumaConceptChain | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=3) validate_string(prompt, strip_whitespace=False, min_length=3)
duration = duration if model != LumaVideoModel.ray_1_6 else None duration = duration if model != LumaVideoModel.ray_1_6 else None
@ -533,7 +535,6 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=PRICE_BADGE_VIDEO, price_badge=PRICE_BADGE_VIDEO,
) )
@classmethod @classmethod
@ -644,6 +645,293 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
) )
def _luma2_uni1_common_inputs(max_image_refs: int) -> list:
return [
IO.Combo.Input(
"style",
options=["auto", "manga"],
default="auto",
tooltip="Style preset. 'auto' picks based on the prompt; "
"'manga' applies a manga/anime aesthetic and requires a portrait "
"aspect ratio (2:3, 9:16, 1:2, 1:3).",
),
IO.Boolean.Input(
"web_search",
default=False,
tooltip="Search the web for visual references before generating.",
),
IO.Autogrow.Input(
"image_ref",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, max_image_refs + 1)],
min=0,
),
optional=True,
tooltip=f"Up to {max_image_refs} reference images for style/content guidance.",
),
]
async def _luma2_upload_image_refs(
cls: type[IO.ComfyNode],
refs: dict | None,
max_count: int,
) -> list[Luma2ImageRef] | None:
if not refs:
return None
out: list[Luma2ImageRef] = []
for key in refs:
url = await upload_image_to_comfyapi(cls, refs[key])
out.append(Luma2ImageRef(url=url))
if len(out) > max_count:
raise ValueError(f"Maximum {max_count} reference images are allowed.")
return out or None
async def _luma2_submit_and_poll(
cls: type[IO.ComfyNode],
request: Luma2GenerationRequest,
) -> Input.Image:
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/luma_2/generations", method="POST"),
response_model=Luma2Generation,
data=request,
)
if not initial.id:
raise RuntimeError("Luma 2 API did not return a generation id.")
final = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/luma_2/generations/{initial.id}", method="GET"),
response_model=Luma2Generation,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: None,
)
if not final.output:
msg = final.failure_reason or "no output returned"
raise RuntimeError(f"Luma 2 generation failed: {msg}")
url = final.output[0].url
if not url:
raise RuntimeError("Luma 2 generation completed without an output URL.")
return await download_url_to_image_tensor(url)
class LumaImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageNode2",
display_name="Luma UNI-1 Image",
category="api node/image/Luma",
description="Generate images from text using the Luma UNI-1 model.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. 16000 characters.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"uni-1",
[
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"3:1",
"2:1",
"16:9",
"3:2",
"1:1",
"2:3",
"9:16",
"1:2",
"1:3",
],
default="auto",
tooltip="Output image aspect ratio. 'auto' lets "
"the model pick based on the prompt.",
),
*_luma2_uni1_common_inputs(max_image_refs=9),
],
),
IO.DynamicCombo.Option(
"uni-1-max",
[
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"3:1",
"2:1",
"16:9",
"3:2",
"1:1",
"2:3",
"9:16",
"1:2",
"1:3",
],
default="auto",
tooltip="Output image aspect ratio. 'auto' lets "
"the model pick based on the prompt.",
),
*_luma2_uni1_common_inputs(max_image_refs=9),
],
),
],
tooltip="Model to use for generation.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
expr="""
(
$m := widgets.model;
$refs := $lookup(inputGroups, "model.image_ref");
$base := $m = "uni-1-max" ? 0.1 : 0.0404;
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=6000)
aspect_ratio = model["aspect_ratio"]
style = model["style"]
allowed_manga_ratios = {"2:3", "9:16", "1:2", "1:3"}
if style == "manga" and aspect_ratio != "auto" and aspect_ratio not in allowed_manga_ratios:
raise ValueError(
f"'manga' style requires a portrait aspect ratio "
f"({', '.join(sorted(allowed_manga_ratios))}) or 'auto'; got '{aspect_ratio}'."
)
request = Luma2GenerationRequest(
prompt=prompt,
model=model["model"],
type="image",
aspect_ratio=aspect_ratio if aspect_ratio != "auto" else None,
style=style if style != "auto" else None,
output_format="png",
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=9),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
class LumaImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit",
category="api node/image/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[
IO.Image.Input(
"source",
tooltip="Source image to edit.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Description of the desired edit. 16000 characters.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"uni-1",
_luma2_uni1_common_inputs(max_image_refs=8),
),
IO.DynamicCombo.Option(
"uni-1-max",
_luma2_uni1_common_inputs(max_image_refs=8),
),
],
tooltip="Model to use for editing.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"], input_groups=["model.image_ref"]),
expr="""
(
$m := widgets.model;
$refs := $lookup(inputGroups, "model.image_ref");
$base := $m = "uni-1-max" ? 0.103 : 0.0434;
{"type":"usd","usd": $round($base + 0.003 * $refs, 4)}
)
""",
),
)
@classmethod
async def execute(
cls,
source: Input.Image,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=6000)
request = Luma2GenerationRequest(
prompt=prompt,
model=model["model"],
type="image_edit",
source=Luma2ImageRef(url=await upload_image_to_comfyapi(cls, source)),
style=model["style"] if model["style"] != "auto" else None,
output_format="png",
web_search=model["web_search"],
image_ref=await _luma2_upload_image_refs(cls, model.get("image_ref"), max_count=8),
)
return IO.NodeOutput(await _luma2_submit_and_poll(cls, request))
class LumaExtension(ComfyExtension): class LumaExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -654,6 +942,8 @@ class LumaExtension(ComfyExtension):
LumaImageToVideoGenerationNode, LumaImageToVideoGenerationNode,
LumaReferenceNode, LumaReferenceNode,
LumaConceptsNode, LumaConceptsNode,
LumaImageNode,
LumaImageEditNode,
] ]

View File

@ -230,7 +230,6 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd, price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -391,7 +390,6 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
price_extractor=lambda _: price_usd, price_extractor=lambda _: price_usd,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -541,7 +539,6 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -782,7 +779,6 @@ class MagnificImageRelightNode(IO.ComfyNode):
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))
@ -924,7 +920,6 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
response_model=TaskResponse, response_model=TaskResponse,
status_extractor=lambda x: x.status, status_extractor=lambda x: x.status,
poll_interval=10.0, poll_interval=10.0,
max_poll_attempts=480,
) )
return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0])) return IO.NodeOutput(await download_url_to_image_tensor(final_response.generated[0]))

View File

@ -1,534 +0,0 @@
import logging
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.moonvalley import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
)
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
API_TXT2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/text-to-video"
API_IMG2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/image-to-video"
MIN_WIDTH = 300
MIN_HEIGHT = 300
MAX_WIDTH = 10000
MAX_HEIGHT = 10000
MIN_VID_WIDTH = 300
MIN_VID_HEIGHT = 300
MAX_VID_WIDTH = 10000
MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
"""Verifies that the initial response contains a task ID."""
return bool(response.id)
def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg)
raise RuntimeError(error_msg)
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.
Args:
video: Input video to validate
Returns:
Validated and potentially trimmed video
Raises:
ValueError: If video doesn't meet requirements
MoonvalleyApiError: If video duration is too short
"""
width, height = _get_video_dimensions(video)
_validate_video_dimensions(width, height)
validate_container_format_is_mp4(video)
return _validate_and_trim_duration(video)
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
raise ValueError(f"Cannot get video dimensions: {e}") from e
def _validate_video_dimensions(width: int, height: int) -> None:
"""Validates video dimensions meet Moonvalley V2V requirements."""
supported_resolutions = {
(1920, 1080),
(1080, 1920),
(1152, 1152),
(1536, 1152),
(1152, 1536),
}
if (width, height) not in supported_resolutions:
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
return _trim_if_too_long(video, duration)
def _validate_minimum_duration(duration: float) -> None:
"""Ensures video is at least 5 seconds long."""
if duration < 5:
raise ValueError("Input video must be at least 5 seconds long.")
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
return video
def parse_width_height_from_res(resolution: str):
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
res_map = {
"16:9 (1920 x 1080)": {"width": 1920, "height": 1080},
"9:16 (1080 x 1920)": {"width": 1080, "height": 1920},
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
}
return res_map.get(resolution, {"width": 1920, "height": 1080})
def parse_control_parameter(value):
control_map = {
"Motion Transfer": "motion_control",
"Canny": "canny_control",
"Pose Transfer": "pose_control",
"Depth": "depth_control",
}
return control_map.get(value, control_map["Motion Transfer"])
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
return await poll_op(
cls,
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
response_model=MoonvalleyPromptResponse,
status_extractor=lambda r: (r.status if r and r.status else None),
poll_interval=16.0,
max_poll_attempts=240,
)
class MoonvalleyImg2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyImg2VideoNode",
display_name="Moonvalley Marey Image to Video",
category="api node/video/Moonvalley Marey",
description="Moonvalley Marey Image to Video Node",
inputs=[
IO.Image.Input(
"image",
tooltip="The reference image used to generate the video",
),
IO.String.Input(
"prompt",
multiline=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
# "21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
IO.Float.Input(
"prompt_adherence",
default=4.5,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=True,
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
width=width_height["width"],
height=width_height["height"],
use_negative_prompts=True,
)
# Get MIME type from tensor - assuming PNG format for image tensors
mime_type = "image/png"
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(
image_url=image_url, prompt_text=prompt, inference_params=inference_params
),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
video = await download_url_to_video_output(final_response.output_url)
return IO.NodeOutput(video)
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyVideo2VideoNode",
display_name="Moonvalley Marey Video to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Describes the video to generate",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Random seed value",
control_after_generate=False,
),
IO.Video.Input(
"video",
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
),
IO.Combo.Input(
"control_type",
options=["Motion Transfer", "Pose Transfer"],
default="Motion Transfer",
optional=True,
),
IO.Int.Input(
"motion_intensity",
default=100,
min=0,
max=100,
step=1,
tooltip="Only used if control_type is 'Motion Transfer'",
optional=True,
),
IO.Int.Input(
"steps",
default=60,
min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
tooltip="Number of inference steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 2.25}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
seed: int,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
video_url = await upload_video_to_comfyapi(cls, validated_video)
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
# Only include motion_intensity for Motion Transfer
control_params = {}
if control_type == "Motion Transfer" and motion_intensity is not None:
control_params["motion_intensity"] = motion_intensity
inference_params = MoonvalleyVideoToVideoInferenceParams(
negative_prompt=negative_prompt,
seed=seed,
control_params=control_params,
steps=steps,
guidance_scale=prompt_adherence,
)
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyVideoToVideoRequest(
control_type=parse_control_parameter(control_type),
video_url=video_url,
prompt_text=prompt,
inference_params=inference_params,
),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="MoonvalleyTxt2VideoNode",
display_name="Moonvalley Marey Text to Video",
category="api node/video/Moonvalley Marey",
description="",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
tooltip="Negative prompt text",
),
IO.Combo.Input(
"resolution",
options=[
"16:9 (1920 x 1080)",
"9:16 (1080 x 1920)",
"1:1 (1152 x 1152)",
"4:3 (1536 x 1152)",
"3:4 (1152 x 1536)",
"21:9 (2560 x 1080)",
],
default="16:9 (1920 x 1080)",
tooltip="Resolution of the output video",
),
IO.Float.Input(
"prompt_adherence",
default=4.0,
min=1.0,
max=20.0,
step=1.0,
tooltip="Guidance scale for generation control",
),
IO.Int.Input(
"seed",
default=9,
min=0,
max=4294967295,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Random seed value",
),
IO.Int.Input(
"steps",
default=80,
min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(),
expr="""{"type":"usd","usd": 1.5}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
resolution: str,
prompt_adherence: float,
seed: int,
steps: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
width_height = parse_width_height_from_res(resolution)
inference_params = MoonvalleyTextToVideoInferenceParams(
negative_prompt=negative_prompt,
steps=steps,
seed=seed,
guidance_scale=prompt_adherence,
num_frames=128,
width=width_height["width"],
height=width_height["height"],
)
task_creation_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
response_model=MoonvalleyPromptResponse,
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
)
validate_task_creation_response(task_creation_response)
final_response = await get_response(cls, task_creation_response.id)
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
class MoonvalleyExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
MoonvalleyImg2VideoNode,
MoonvalleyTxt2VideoNode,
MoonvalleyVideo2VideoNode,
]
async def comfy_entrypoint() -> MoonvalleyExtension:
return MoonvalleyExtension()

View File

@ -39,16 +39,18 @@ STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
class SupportedOpenAIModel(str, Enum): class SupportedOpenAIModel(str, Enum):
o4_mini = "o4-mini" gpt_5_5_pro = "gpt-5.5-pro"
o1 = "o1" gpt_5_5 = "gpt-5.5"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
gpt_5 = "gpt-5" gpt_5 = "gpt-5"
gpt_5_mini = "gpt-5-mini" gpt_5_mini = "gpt-5-mini"
gpt_5_nano = "gpt-5-nano" gpt_5_nano = "gpt-5-nano"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
o4_mini = "o4-mini"
o3 = "o3"
o1_pro = "o1-pro"
o1 = "o1"
async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
@ -357,13 +359,17 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) ->
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0
def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None:
return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0
class OpenAIGPTImage1(IO.ComfyNode): class OpenAIGPTImage1(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="OpenAIGPTImage1", node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 1.5", display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI", category="api node/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.", description="Generates images synchronously via OpenAI's GPT Image endpoint.",
inputs=[ inputs=[
@ -401,8 +407,19 @@ class OpenAIGPTImage1(IO.ComfyNode):
IO.Combo.Input( IO.Combo.Input(
"size", "size",
default="auto", default="auto",
options=["auto", "1024x1024", "1024x1536", "1536x1024"], options=[
tooltip="Image size", "auto",
"1024x1024",
"1024x1536",
"1536x1024",
"2048x2048",
"2048x1152",
"1152x2048",
"3840x2160",
"2160x3840",
"Custom",
],
tooltip="Image size. Select 'Custom' to use the custom width and height (GPT Image 2 only).",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -427,8 +444,26 @@ class OpenAIGPTImage1(IO.ComfyNode):
), ),
IO.Combo.Input( IO.Combo.Input(
"model", "model",
options=["gpt-image-1", "gpt-image-1.5"], options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"],
default="gpt-image-1.5", default="gpt-image-2",
optional=True,
),
IO.Int.Input(
"custom_width",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
optional=True,
),
IO.Int.Input(
"custom_height",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
optional=True, optional=True,
), ),
], ],
@ -442,23 +477,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]),
expr=""" expr="""
( (
$ranges := { $ranges := {
"low": [0.011, 0.02], "gpt-image-1": {
"medium": [0.046, 0.07], "low": [0.011, 0.02],
"high": [0.167, 0.3] "medium": [0.042, 0.07],
"high": [0.167, 0.25]
},
"gpt-image-1.5": {
"low": [0.009, 0.02],
"medium": [0.034, 0.062],
"high": [0.133, 0.22]
},
"gpt-image-2": {
"low": [0.0048, 0.019],
"medium": [0.041, 0.168],
"high": [0.165, 0.67]
}
}; };
$range := $lookup($ranges, widgets.quality); $range := $lookup($lookup($ranges, widgets.model), widgets.quality);
$n := widgets.n; $nRaw := widgets.n;
$n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
($n = 1) ($n = 1)
? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
: { : {
"type":"range_usd", "type":"range_usd",
"min_usd": $range[0], "min_usd": $range[0] * $n,
"max_usd": $range[1], "max_usd": $range[1] * $n,
"format": { "suffix": " x " & $string($n) & "/Run" } "format": { "suffix": "/Run", "approximate": true }
} }
) )
""", """,
@ -476,6 +524,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
mask: Input.Image | None = None, mask: Input.Image | None = None,
n: int = 1, n: int = 1,
size: str = "1024x1024", size: str = "1024x1024",
custom_width: int = 1024,
custom_height: int = 1024,
model: str = "gpt-image-1", model: str = "gpt-image-1",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -483,10 +533,36 @@ class OpenAIGPTImage1(IO.ComfyNode):
if mask is not None and image is None: if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image") raise ValueError("Cannot use a mask without an input image")
if size == "Custom":
if model != "gpt-image-2":
raise ValueError("Custom resolution is only supported by GPT Image 2 model")
if custom_width % 16 != 0 or custom_height % 16 != 0:
raise ValueError(f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}")
if max(custom_width, custom_height) > 3840:
raise ValueError(f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}")
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
if ratio > 3:
raise ValueError(
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
)
total_pixels = custom_width * custom_height
if not 655_360 <= total_pixels <= 8_294_400:
raise ValueError(
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
)
size = f"{custom_width}x{custom_height}"
elif model in ("gpt-image-1", "gpt-image-1.5"):
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")
if model == "gpt-image-1": if model == "gpt-image-1":
price_extractor = calculate_tokens_price_image_1 price_extractor = calculate_tokens_price_image_1
elif model == "gpt-image-1.5": elif model == "gpt-image-1.5":
price_extractor = calculate_tokens_price_image_1_5 price_extractor = calculate_tokens_price_image_1_5
elif model == "gpt-image-2":
price_extractor = calculate_tokens_price_image_2_0
if background == "transparent":
raise ValueError("Transparent background is not supported for GPT Image 2 model")
else: else:
raise ValueError(f"Unknown model: {model}") raise ValueError(f"Unknown model: {model}")
@ -665,6 +741,16 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.002, 0.008], "usd": [0.002, 0.008],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} }
: $contains($m, "gpt-5.5-pro") ? {
"type": "list_usd",
"usd": [0.03, 0.18],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5.5") ? {
"type": "list_usd",
"usd": [0.005, 0.03],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-5-nano") ? { : $contains($m, "gpt-5-nano") ? {
"type": "list_usd", "type": "list_usd",
"usd": [0.00005, 0.0004], "usd": [0.00005, 0.0004],

Some files were not shown because too many files have changed in this diff Show More