Merge remote-tracking branch 'upstream/master' into gemma4

This commit is contained in:
kijai 2026-04-30 14:29:14 +03:00
commit cbaa07bf05
72 changed files with 35740 additions and 1285 deletions

View File

@ -0,0 +1,45 @@
name: Tag Dispatch to Cloud
on:
push:
tags:
- 'v*'
jobs:
dispatch-cloud:
runs-on: ubuntu-latest
steps:
- name: Send repository dispatch to cloud
env:
DISPATCH_TOKEN: ${{ secrets.CLOUD_REPO_DISPATCH_TOKEN }}
RELEASE_TAG: ${{ github.ref_name }}
run: |
set -euo pipefail
if [ -z "${DISPATCH_TOKEN:-}" ]; then
echo "::error::CLOUD_REPO_DISPATCH_TOKEN is required but not set."
exit 1
fi
RELEASE_URL="https://github.com/${{ github.repository }}/releases/tag/${RELEASE_TAG}"
PAYLOAD="$(jq -n \
--arg release_tag "$RELEASE_TAG" \
--arg release_url "$RELEASE_URL" \
'{
event_type: "comfyui_tag_pushed",
client_payload: {
release_tag: $release_tag,
release_url: $release_url
}
}')"
curl -fsSL \
-X POST \
-H "Accept: application/vnd.github+json" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer ${DISPATCH_TOKEN}" \
https://api.github.com/repos/Comfy-Org/cloud/dispatches \
-d "$PAYLOAD"
echo "✅ Dispatched ComfyUI tag ${RELEASE_TAG} to Comfy-Org/cloud"

1
.gitignore vendored
View File

@ -21,6 +21,5 @@ venv*/
*.log *.log
web_custom_versions/ web_custom_versions/
.DS_Store .DS_Store
openapi.yaml
filtered-openapi.yaml filtered-openapi.yaml
uv.lock uv.lock

View File

@ -2,7 +2,6 @@
precision mediump float; precision mediump float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blend mode uniform int u_int0; // Blend mode
uniform int u_int1; // Color tint uniform int u_int1; // Color tint
uniform float u_float0; // Intensity uniform float u_float0; // Intensity
@ -75,7 +74,7 @@ void main() {
float t0 = threshold - 0.15; float t0 = threshold - 0.15;
float t1 = threshold + 0.15; float t1 = threshold + 0.15;
vec2 texelSize = 1.0 / u_resolution; vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius2 = radius * radius; float radius2 = radius * radius;
float sampleScale = clamp(radius * 0.75, 0.35, 1.0); float sampleScale = clamp(radius * 0.75, 0.35, 1.0);

View File

@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12;
const float RADIAL_STRENGTH = 0.0003; const float RADIAL_STRENGTH = 0.0003;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL) uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)
uniform float u_float0; // Blur radius/amount uniform float u_float0; // Blur radius/amount
uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical) uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)
@ -25,7 +24,7 @@ float gaussian(float x, float sigma) {
} }
void main() { void main() {
vec2 texelSize = 1.0 / u_resolution; vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float0, 0.0); float radius = max(u_float0, 0.0);
// Radial (angular) blur - single pass, doesn't use separable // Radial (angular) blur - single pass, doesn't use separable

View File

@ -2,14 +2,13 @@
precision highp float; precision highp float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0 uniform float u_float0; // strength [0.0 2.0] typical: 0.31.0
in vec2 v_texCoord; in vec2 v_texCoord;
layout(location = 0) out vec4 fragColor0; layout(location = 0) out vec4 fragColor0;
void main() { void main() {
vec2 texel = 1.0 / u_resolution; vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
// Sample center and neighbors // Sample center and neighbors
vec4 center = texture(u_image0, v_texCoord); vec4 center = texture(u_image0, v_texCoord);

View File

@ -2,7 +2,6 @@
precision highp float; precision highp float;
uniform sampler2D u_image0; uniform sampler2D u_image0;
uniform vec2 u_resolution;
uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5 uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5
uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels
uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen
@ -19,7 +18,7 @@ float getLuminance(vec3 color) {
} }
void main() { void main() {
vec2 texel = 1.0 / u_resolution; vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));
float radius = max(u_float1, 0.5); float radius = max(u_float1, 0.5);
float amount = u_float0; float amount = u_float0;
float threshold = u_float2; float threshold = u_float2;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -160,7 +160,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Depth to Image (Z-Image-Turbo)", "name": "Depth to Image (Z-Image-Turbo)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [

View File

@ -261,7 +261,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Depth to Video (LTX 2.0)", "name": "Depth to Video (LTX 2.0)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [

File diff suppressed because it is too large Load Diff

View File

@ -268,7 +268,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}",
"from_input" "from_input"
] ]
}, },

View File

@ -331,7 +331,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n",
"from_input" "from_input"
] ]
} }

File diff suppressed because it is too large Load Diff

View File

@ -128,7 +128,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Edit (Flux.2 Klein 4B)", "name": "Image Edit (Flux.2 Klein 4B)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -124,7 +124,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Inpainting (Qwen-image)", "name": "Image Inpainting (Qwen-image)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [

View File

@ -204,7 +204,7 @@
}, },
"revision": 0, "revision": 0,
"config": {}, "config": {},
"name": "local-Image Outpainting (Qwen-Image)", "name": "Image Outpainting (Qwen-Image)",
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [

View File

@ -1,15 +1,14 @@
{ {
"id": "1a761372-7c82-4016-b9bf-fa285967e1e9",
"revision": 0, "revision": 0,
"last_node_id": 83, "last_node_id": 176,
"last_link_id": 0, "last_link_id": 0,
"nodes": [ "nodes": [
{ {
"id": 83, "id": 176,
"type": "f754a936-daaf-4b6e-9658-41fdc54d301d", "type": "2d2e3c8e-53b3-4618-be52-6d1d99382f0e",
"pos": [ "pos": [
61.999827823554256, -1150,
153.3332507624185 200
], ],
"size": [ "size": [
400, 400,
@ -56,6 +55,38 @@
"name": "layers" "name": "layers"
}, },
"link": null "link": null
},
{
"name": "seed",
"type": "INT",
"widget": {
"name": "seed"
},
"link": null
},
{
"name": "unet_name",
"type": "COMBO",
"widget": {
"name": "unet_name"
},
"link": null
},
{
"name": "clip_name",
"type": "COMBO",
"widget": {
"name": "clip_name"
},
"link": null
},
{
"name": "vae_name",
"type": "COMBO",
"widget": {
"name": "vae_name"
},
"link": null
} }
], ],
"outputs": [ "outputs": [
@ -66,28 +97,41 @@
"links": [] "links": []
} }
], ],
"title": "Image to Layers (Qwen-Image-Layered)",
"properties": { "properties": {
"proxyWidgets": [ "proxyWidgets": [
[ [
"-1", "6",
"text" "text"
], ],
[ [
"-1", "3",
"steps" "steps"
], ],
[ [
"-1", "3",
"cfg" "cfg"
], ],
[ [
"-1", "83",
"layers" "layers"
], ],
[ [
"3", "3",
"seed" "seed"
], ],
[
"37",
"unet_name"
],
[
"38",
"clip_name"
],
[
"39",
"vae_name"
],
[ [
"3", "3",
"control_after_generate" "control_after_generate"
@ -95,6 +139,11 @@
], ],
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -103,25 +152,20 @@
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, },
"widgets_values": [ "widgets_values": []
"",
20,
2.5,
2
]
} }
], ],
"links": [], "links": [],
"groups": [], "version": 0.4,
"definitions": { "definitions": {
"subgraphs": [ "subgraphs": [
{ {
"id": "f754a936-daaf-4b6e-9658-41fdc54d301d", "id": "2d2e3c8e-53b3-4618-be52-6d1d99382f0e",
"version": 1, "version": 1,
"state": { "state": {
"lastGroupId": 3, "lastGroupId": 8,
"lastNodeId": 83, "lastNodeId": 176,
"lastLinkId": 159, "lastLinkId": 380,
"lastRerouteId": 0 "lastRerouteId": 0
}, },
"revision": 0, "revision": 0,
@ -130,10 +174,10 @@
"inputNode": { "inputNode": {
"id": -10, "id": -10,
"bounding": [ "bounding": [
-510, -720,
523, 720,
120, 120,
140 220
] ]
}, },
"outputNode": { "outputNode": {
@ -156,8 +200,8 @@
], ],
"localized_name": "image", "localized_name": "image",
"pos": [ "pos": [
-410, -620,
543 740
] ]
}, },
{ {
@ -168,8 +212,8 @@
150 150
], ],
"pos": [ "pos": [
-410, -620,
563 760
] ]
}, },
{ {
@ -180,8 +224,8 @@
153 153
], ],
"pos": [ "pos": [
-410, -620,
583 780
] ]
}, },
{ {
@ -192,8 +236,8 @@
154 154
], ],
"pos": [ "pos": [
-410, -620,
603 800
] ]
}, },
{ {
@ -204,8 +248,56 @@
159 159
], ],
"pos": [ "pos": [
-410, -620,
623 820
]
},
{
"id": "9f76338b-f4ca-4bb3-b61a-57b3f233061e",
"name": "seed",
"type": "INT",
"linkIds": [
377
],
"pos": [
-620,
840
]
},
{
"id": "8d0422d5-5eee-4f7e-9817-dc613cc62eca",
"name": "unet_name",
"type": "COMBO",
"linkIds": [
378
],
"pos": [
-620,
860
]
},
{
"id": "552eece2-a735-4d00-ae78-ded454622bc1",
"name": "clip_name",
"type": "COMBO",
"linkIds": [
379
],
"pos": [
-620,
880
]
},
{
"id": "1e6d141c-d0f9-4a2b-895c-b6780e57cfa0",
"name": "vae_name",
"type": "COMBO",
"linkIds": [
380
],
"pos": [
-620,
900
] ]
} }
], ],
@ -231,14 +323,14 @@
"type": "CLIPLoader", "type": "CLIPLoader",
"pos": [ "pos": [
-320, -320,
310 360
], ],
"size": [ "size": [
346.7470703125, 350,
106 150
], ],
"flags": {}, "flags": {},
"order": 0, "order": 5,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -248,7 +340,7 @@
"widget": { "widget": {
"name": "clip_name" "name": "clip_name"
}, },
"link": null "link": 379
}, },
{ {
"localized_name": "type", "localized_name": "type",
@ -283,9 +375,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "CLIPLoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPLoader",
"models": [ "models": [
{ {
"name": "qwen_2.5_vl_7b_fp8_scaled.safetensors", "name": "qwen_2.5_vl_7b_fp8_scaled.safetensors",
@ -312,14 +409,14 @@
"type": "VAELoader", "type": "VAELoader",
"pos": [ "pos": [
-320, -320,
460 580
], ],
"size": [ "size": [
346.7470703125, 350,
58 110
], ],
"flags": {}, "flags": {},
"order": 1, "order": 6,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -329,7 +426,7 @@
"widget": { "widget": {
"name": "vae_name" "name": "vae_name"
}, },
"link": null "link": 380
} }
], ],
"outputs": [ "outputs": [
@ -345,9 +442,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAELoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAELoader",
"models": [ "models": [
{ {
"name": "qwen_image_layered_vae.safetensors", "name": "qwen_image_layered_vae.safetensors",
@ -375,11 +477,11 @@
420 420
], ],
"size": [ "size": [
425.27801513671875, 430,
180.6060791015625 190
], ],
"flags": {}, "flags": {},
"order": 3, "order": 2,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -411,9 +513,14 @@
], ],
"title": "CLIP Text Encode (Negative Prompt)", "title": "CLIP Text Encode (Negative Prompt)",
"properties": { "properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -432,12 +539,12 @@
"id": 70, "id": 70,
"type": "ReferenceLatent", "type": "ReferenceLatent",
"pos": [ "pos": [
330, 140,
670 700
], ],
"size": [ "size": [
204.1666717529297, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
@ -470,9 +577,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ReferenceLatent",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ReferenceLatent",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -480,19 +592,18 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 69, "id": 69,
"type": "ReferenceLatent", "type": "ReferenceLatent",
"pos": [ "pos": [
330, 160,
710 820
], ],
"size": [ "size": [
204.1666717529297, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
@ -525,9 +636,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ReferenceLatent",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ReferenceLatent",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -535,8 +651,7 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 66, "id": 66,
@ -547,10 +662,10 @@
], ],
"size": [ "size": [
270, 270,
58 110
], ],
"flags": {}, "flags": {},
"order": 4, "order": 7,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -580,9 +695,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "ModelSamplingAuraFlow",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "ModelSamplingAuraFlow",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -600,11 +720,11 @@
"type": "LatentCutToBatch", "type": "LatentCutToBatch",
"pos": [ "pos": [
830, 830,
160 140
], ],
"size": [ "size": [
270, 270,
82 140
], ],
"flags": {}, "flags": {},
"order": 11, "order": 11,
@ -646,9 +766,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "LatentCutToBatch",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "LatentCutToBatch",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -666,12 +791,12 @@
"id": 71, "id": 71,
"type": "VAEEncode", "type": "VAEEncode",
"pos": [ "pos": [
100, -280,
690 780
], ],
"size": [ "size": [
140, 230,
46 100
], ],
"flags": { "flags": {
"collapsed": false "collapsed": false
@ -704,9 +829,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAEEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -714,24 +844,23 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 8, "id": 8,
"type": "VAEDecode", "type": "VAEDecode",
"pos": [ "pos": [
850, 850,
310 370
], ],
"size": [ "size": [
210, 210,
46 50
], ],
"flags": { "flags": {
"collapsed": true "collapsed": true
}, },
"order": 7, "order": 3,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -759,9 +888,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "VAEDecode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "VAEDecode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -769,8 +903,7 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 6, "id": 6,
@ -780,11 +913,11 @@
180 180
], ],
"size": [ "size": [
422.84503173828125, 430,
164.31304931640625 170
], ],
"flags": {}, "flags": {},
"order": 6, "order": 1,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -816,9 +949,14 @@
], ],
"title": "CLIP Text Encode (Positive Prompt)", "title": "CLIP Text Encode (Positive Prompt)",
"properties": { "properties": {
"Node name for S&R": "CLIPTextEncode",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "CLIPTextEncode",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -838,14 +976,14 @@
"type": "KSampler", "type": "KSampler",
"pos": [ "pos": [
530, 530,
280 340
], ],
"size": [ "size": [
270, 270,
400 400
], ],
"flags": {}, "flags": {},
"order": 5, "order": 0,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -879,7 +1017,7 @@
"widget": { "widget": {
"name": "seed" "name": "seed"
}, },
"link": null "link": 377
}, },
{ {
"localized_name": "steps", "localized_name": "steps",
@ -939,9 +1077,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "KSampler",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "KSampler",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -964,12 +1107,12 @@
"id": 78, "id": 78,
"type": "GetImageSize", "type": "GetImageSize",
"pos": [ "pos": [
80, -280,
790 930
], ],
"size": [ "size": [
210, 230,
136 140
], ],
"flags": {}, "flags": {},
"order": 12, "order": 12,
@ -1007,9 +1150,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "GetImageSize",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "GetImageSize",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -1017,23 +1165,23 @@
"secondTabText": "Send Back", "secondTabText": "Send Back",
"secondTabOffset": 80, "secondTabOffset": 80,
"secondTabWidth": 65 "secondTabWidth": 65
}, }
"widgets_values": []
}, },
{ {
"id": 83, "id": 83,
"type": "EmptyQwenImageLayeredLatentImage", "type": "EmptyQwenImageLayeredLatentImage",
"pos": [ "pos": [
320, -280,
790 1120
], ],
"size": [ "size": [
330.9341796875, 340,
130 200
], ],
"flags": {}, "flags": {},
"order": 13, "order": 13,
"mode": 0, "mode": 0,
"showAdvanced": true,
"inputs": [ "inputs": [
{ {
"localized_name": "width", "localized_name": "width",
@ -1083,9 +1231,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "EmptyQwenImageLayeredLatentImage",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "EmptyQwenImageLayeredLatentImage",
"enableTabs": false, "enableTabs": false,
"tabWidth": 65, "tabWidth": 65,
"tabXOffset": 10, "tabXOffset": 10,
@ -1109,11 +1262,11 @@
180 180
], ],
"size": [ "size": [
346.7470703125, 350,
82 110
], ],
"flags": {}, "flags": {},
"order": 2, "order": 4,
"mode": 0, "mode": 0,
"inputs": [ "inputs": [
{ {
@ -1123,7 +1276,7 @@
"widget": { "widget": {
"name": "unet_name" "name": "unet_name"
}, },
"link": null "link": 378
}, },
{ {
"localized_name": "weight_dtype", "localized_name": "weight_dtype",
@ -1147,9 +1300,14 @@
} }
], ],
"properties": { "properties": {
"Node name for S&R": "UNETLoader",
"cnr_id": "comfy-core", "cnr_id": "comfy-core",
"ver": "0.5.1", "ver": "0.5.1",
"ue_properties": {
"widget_ue_connectable": {},
"input_ue_unconnectable": {},
"version": "7.7"
},
"Node name for S&R": "UNETLoader",
"models": [ "models": [
{ {
"name": "qwen_image_layered_bf16.safetensors", "name": "qwen_image_layered_bf16.safetensors",
@ -1191,8 +1349,8 @@
"bounding": [ "bounding": [
-330, -330,
110, 110,
366.7470703125, 370,
421.6 610
], ],
"color": "#3f789e", "color": "#3f789e",
"font_size": 24, "font_size": 24,
@ -1391,6 +1549,38 @@
"target_id": 83, "target_id": 83,
"target_slot": 2, "target_slot": 2,
"type": "INT" "type": "INT"
},
{
"id": 377,
"origin_id": -10,
"origin_slot": 5,
"target_id": 3,
"target_slot": 4,
"type": "INT"
},
{
"id": 378,
"origin_id": -10,
"origin_slot": 6,
"target_id": 37,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 379,
"origin_id": -10,
"origin_slot": 7,
"target_id": 38,
"target_slot": 0,
"type": "COMBO"
},
{
"id": 380,
"origin_id": -10,
"origin_slot": 8,
"target_id": 39,
"target_slot": 0,
"type": "COMBO"
} }
], ],
"extra": { "extra": {
@ -1400,7 +1590,6 @@
} }
] ]
}, },
"config": {},
"extra": { "extra": {
"ds": { "ds": {
"scale": 1.14, "scale": 1.14,
@ -1409,7 +1598,6 @@
6.855893974423647 6.855893974423647
] ]
}, },
"workflowRendererVersion": "LG" "ue_links": []
}, }
"version": 0.4
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -267,7 +267,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 2.0] typical: 0.31.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}",
"from_input" "from_input"
] ]
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -383,7 +383,7 @@
"Node name for S&R": "GLSLShader" "Node name for S&R": "GLSLShader"
}, },
"widgets_values": [ "widgets_values": [
"#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n",
"from_input" "from_input"
] ]
} }

View File

@ -224,6 +224,7 @@ class Flux2(LatentFormat):
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
self.taesd_decoder_name = "taef2_decoder"
def process_in(self, latent): def process_in(self, latent):
return latent return latent
@ -783,3 +784,10 @@ class ZImagePixelSpace(ChromaRadiance):
No VAE encoding/decoding the model operates directly on RGB pixels. No VAE encoding/decoding the model operates directly on RGB pixels.
""" """
pass pass
class CogVideoX(LatentFormat):
latent_channels = 16
latent_dimensions = 3
def __init__(self):
self.scale_factor = 1.15258426

View File

573
comfy/ldm/cogvideo/model.py Normal file
View File

@ -0,0 +1,573 @@
# CogVideoX 3D Transformer - ported to ComfyUI native ops
# Architecture reference: diffusers CogVideoXTransformer3DModel
# Style reference: comfy/ldm/wan/model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.patcher_extension
import comfy.ldm.common_dit
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
"""Returns (cos, sin) each with shape [seq_len, dim].
Frequencies are computed at dim//2 resolution then repeat_interleaved
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
"""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
angles = torch.outer(pos.float(), freqs.float())
cos = angles.cos().repeat_interleave(2, dim=-1).float()
sin = angles.sin().repeat_interleave(2, dim=-1).float()
return (cos, sin)
def apply_rotary_emb(x, freqs_cos_sin):
"""Apply CogVideoX rotary embedding to query or key tensor.
x: [B, heads, seq_len, head_dim]
freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2]
Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
"""
cos, sin = freqs_cos_sin
cos = cos[None, None, :, :].to(x.device)
sin = sin[None, None, :, :].to(x.device)
# Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
args = timesteps[:, None].float() * freqs[None] * scale
embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
if flip_sin_to_cos:
embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
embed_dim_spatial = 2 * (embed_dim // 3)
embed_dim_temporal = embed_dim // 3
pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
T, H, W = grid_t.shape
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
return pos_embed
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
T, H, W = grid_h.shape
half_dim = embed_dim // 2
pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
return torch.cat([pos_h, pos_w], dim=-1)
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
half = embed_dim // 2
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
args = pos.float().reshape(-1)[:, None] * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if embed_dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
class CogVideoXPatchEmbed(nn.Module):
def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
text_dim=4096, bias=True, sample_width=90, sample_height=60,
sample_frames=49, temporal_compression_ratio=4,
max_text_seq_length=226, spatial_interpolation_scale=1.875,
temporal_interpolation_scale=1.0, use_positional_embeddings=True,
use_learned_positional_embeddings=True,
device=None, dtype=None, operations=None):
super().__init__()
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.dim = dim
self.sample_height = sample_height
self.sample_width = sample_width
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings
self.use_learned_positional_embeddings = use_learned_positional_embeddings
if patch_size_t is None:
self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
else:
self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
if use_positional_embeddings or use_learned_positional_embeddings:
persistent = use_learned_positional_embeddings
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
if self.patch_size_t is not None:
post_time_compression_frames = post_time_compression_frames // self.patch_size_t
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
pos_embedding = get_3d_sincos_pos_embed(
self.dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=device,
)
pos_embedding = pos_embedding.reshape(-1, self.dim)
joint_pos_embedding = pos_embedding.new_zeros(
1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
return joint_pos_embedding
def forward(self, text_embeds, image_embeds):
input_dtype = text_embeds.dtype
text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype)
batch_size, num_frames, channels, height, width = image_embeds.shape
proj_dtype = self.proj.weight.dtype
if self.patch_size_t is None:
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3)
image_embeds = image_embeds.flatten(1, 2)
else:
p = self.patch_size
p_t = self.patch_size_t
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
image_embeds = image_embeds.reshape(
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
)
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype)
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
text_seq_length = text_embeds.shape[1]
num_image_patches = image_embeds.shape[1]
if self.use_learned_positional_embeddings:
image_pos = self.pos_embedding[
:, self.max_text_seq_length:self.max_text_seq_length + num_image_patches
].to(device=embeds.device, dtype=embeds.dtype)
else:
image_pos = get_3d_sincos_pos_embed(
self.dim,
(width // self.patch_size, height // self.patch_size),
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
device=embeds.device,
).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype)
# Build joint: zeros for text + sincos for image
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
joint_pos[:, text_seq_length:] = image_pos
embeds = embeds + joint_pos
return embeds
class CogVideoXLayerNormZero(nn.Module):
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
def forward(self, hidden_states, encoder_hidden_states, temb):
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
class CogVideoXAdaLayerNorm(nn.Module):
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
device=None, dtype=None, operations=None):
super().__init__()
self.silu = nn.SiLU()
self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
def forward(self, x, temb):
temb = self.linear(self.silu(temb))
shift, scale = temb.chunk(2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class CogVideoXBlock(nn.Module):
def __init__(self, dim, num_heads, head_dim, time_dim,
eps=1e-5, ff_inner_dim=None, ff_bias=True,
device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = head_dim
self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
# Self-attention (joint text + latent)
self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
# Feed-forward (GELU approximate)
inner_dim = ff_inner_dim or dim * 4
self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None):
if transformer_options is None:
transformer_options = {}
text_seq_length = encoder_hidden_states.size(1)
# Norm & modulate
norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
# Joint self-attention
qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
b, s, _ = qkv_input.shape
n, d = self.num_heads, self.head_dim
q = self.q(qkv_input).view(b, s, n, d)
k = self.k(qkv_input).view(b, s, n, d)
v = self.v(qkv_input)
q = self.norm_q(q).view(b, s, n, d)
k = self.norm_k(k).view(b, s, n, d)
# Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
if image_rotary_emb is not None:
q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
k_img = k[:, text_seq_length:].transpose(1, 2)
q_img = apply_rotary_emb(q_img, image_rotary_emb)
k_img = apply_rotary_emb(k_img, image_rotary_emb)
q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
attn_out = optimized_attention(
q.reshape(b, s, n * d),
k.reshape(b, s, n * d),
v,
heads=self.num_heads,
transformer_options=transformer_options,
)
attn_out = self.attn_out(attn_out)
attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
hidden_states = hidden_states + gate_msa * attn_hidden
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
# Norm & modulate for FF
norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
# Feed-forward (GELU on concatenated text + latent)
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(nn.Module):
def __init__(self,
num_attention_heads=30,
attention_head_dim=64,
in_channels=16,
out_channels=16,
flip_sin_to_cos=True,
freq_shift=0,
time_embed_dim=512,
ofs_embed_dim=None,
text_embed_dim=4096,
num_layers=30,
dropout=0.0,
attention_bias=True,
sample_width=90,
sample_height=60,
sample_frames=49,
patch_size=2,
patch_size_t=None,
temporal_compression_ratio=4,
max_text_seq_length=226,
spatial_interpolation_scale=1.875,
temporal_interpolation_scale=1.0,
use_rotary_positional_embeddings=False,
use_learned_positional_embeddings=False,
patch_bias=True,
image_model=None,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.dtype = dtype
dim = num_attention_heads * attention_head_dim
self.dim = dim
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.in_channels = in_channels
self.out_channels = out_channels
self.patch_size = patch_size
self.patch_size_t = patch_size_t
self.max_text_seq_length = max_text_seq_length
self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(
patch_size=patch_size,
patch_size_t=patch_size_t,
in_channels=in_channels,
dim=dim,
text_dim=text_embed_dim,
bias=patch_bias,
sample_width=sample_width,
sample_height=sample_height,
sample_frames=sample_frames,
temporal_compression_ratio=temporal_compression_ratio,
max_text_seq_length=max_text_seq_length,
spatial_interpolation_scale=spatial_interpolation_scale,
temporal_interpolation_scale=temporal_interpolation_scale,
use_positional_embeddings=not use_rotary_positional_embeddings,
use_learned_positional_embeddings=use_learned_positional_embeddings,
device=device, dtype=torch.float32, operations=operations,
)
# 2. Time embedding
self.time_proj_dim = dim
self.time_proj_flip = flip_sin_to_cos
self.time_proj_shift = freq_shift
self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
self.time_embedding_act = nn.SiLU()
self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
# Optional OFS embedding (CogVideoX 1.5 I2V)
self.ofs_proj_dim = ofs_embed_dim
if ofs_embed_dim:
self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
self.ofs_embedding_act = nn.SiLU()
self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
else:
self.ofs_embedding_linear_1 = None
# 3. Transformer blocks
self.blocks = nn.ModuleList([
CogVideoXBlock(
dim=dim,
num_heads=num_attention_heads,
head_dim=attention_head_dim,
time_dim=time_embed_dim,
eps=1e-5,
device=device, dtype=dtype, operations=operations,
)
for _ in range(num_layers)
])
self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
# 4. Output
self.norm_out = CogVideoXAdaLayerNorm(
time_dim=time_embed_dim, dim=dim, eps=1e-5,
device=device, dtype=dtype, operations=operations,
)
if patch_size_t is None:
output_dim = patch_size * patch_size * out_channels
else:
output_dim = patch_size * patch_size * patch_size_t * out_channels
self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.temporal_compression_ratio = temporal_compression_ratio
def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, ofs, transformer_options, **kwargs)
def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs):
if transformer_options is None:
transformer_options = {}
# ComfyUI passes [B, C, T, H, W]
batch_size, channels, t, h, w = x.shape
# Pad to patch size (temporal + spatial), same pattern as WAN
p_t = self.patch_size_t if self.patch_size_t is not None else 1
x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
# CogVideoX expects [B, T, C, H, W]
x = x.permute(0, 2, 1, 3, 4)
batch_size, num_frames, channels, height, width = x.shape
# Time embedding
t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
t_emb = t_emb.to(dtype=x.dtype)
emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
if self.ofs_embedding_linear_1 is not None and ofs is not None:
ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
ofs_emb = ofs_emb.to(dtype=x.dtype)
ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
emb = emb + ofs_emb
# Patch embedding
hidden_states = self.patch_embed(context, x)
text_seq_length = context.shape[1]
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
# Rotary embeddings (if used)
image_rotary_emb = None
if self.use_rotary_positional_embeddings:
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
if self.patch_size_t is None:
post_time = num_frames
else:
post_time = num_frames // self.patch_size_t
image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
# Transformer blocks
for i, block in enumerate(self.blocks):
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options,
)
hidden_states = self.norm_final(hidden_states)
# Output projection
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# Unpatchify
p = self.patch_size
p_t = self.patch_size_t
if p_t is None:
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
else:
output = hidden_states.reshape(
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
)
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
# Back to ComfyUI format [B, C, T, H, W] and crop padding
output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
return output
def _get_rotary_emb(self, h, w, t, device):
"""Compute CogVideoX 3D rotary positional embeddings.
For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode grid positions
are integer arange computed at max_size, then sliced to actual size.
For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
scaled by spatial_interpolation_scale.
"""
d = self.attention_head_dim
dim_t = d // 4
dim_h = d // 8 * 3
dim_w = d // 8 * 3
if self.patch_size_t is not None:
# CogVideoX 1.5: "slice" mode — positions are simple integer indices
# Compute at max(sample_size, actual_size) then slice to actual
base_h = self.patch_embed.sample_height // self.patch_size
base_w = self.patch_embed.sample_width // self.patch_size
max_h = max(base_h, h)
max_w = max(base_w, w)
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
grid_t = torch.arange(t, device=device, dtype=torch.float32)
else:
# CogVideoX 1.0: "linspace" mode with interpolation scale
grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
grid_t = torch.arange(t, device=device, dtype=torch.float32)
freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
t_cos, t_sin = freqs_t
h_cos, h_sin = freqs_h
w_cos, w_sin = freqs_w
# Slice to actual size (for "slice" mode where grids may be larger)
t_cos, t_sin = t_cos[:t], t_sin[:t]
h_cos, h_sin = h_cos[:h], h_sin[:h]
w_cos, w_sin = w_cos[:w], w_sin[:w]
# Broadcast and concatenate into [T*H*W, head_dim]
t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
return (cos, sin)

566
comfy/ldm/cogvideo/vae.py Normal file
View File

@ -0,0 +1,566 @@
# CogVideoX VAE - ported to ComfyUI native ops
# Architecture reference: diffusers AutoencoderKLCogVideoX
# Style reference: comfy/ldm/wan/vae.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
"""Causal 3D convolution with temporal padding.
Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has
a single temporal frame and no cache, the 3D conv weight is sliced to act
as a 2D conv, avoiding computation on zero-padded temporal dimensions.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
time_kernel, height_kernel, width_kernel = kernel_size
self.time_kernel_size = time_kernel
self.pad_mode = pad_mode
height_pad = (height_kernel - 1) // 2
width_pad = (width_kernel - 1) // 2
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
dilation = (dilation, 1, 1)
self.conv = ops.Conv3d(
in_channels, out_channels, kernel_size,
stride=stride, dilation=dilation,
padding=(0, height_pad, width_pad),
)
def forward(self, x, conv_cache=None):
if self.pad_mode == "replicate":
x = F.pad(x, self.time_causal_padding, mode="replicate")
conv_cache = None
else:
kernel_t = self.time_kernel_size
if kernel_t > 1:
if conv_cache is None and x.shape[2] == 1:
# Fast path: single frame, no cache. All temporal padding
# frames are copies of the input (replicate-style), so the
# 3D conv reduces to a 2D conv with summed temporal kernel.
w = comfy.ops.cast_to_input(self.conv.weight, x)
b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
w2d = w.sum(dim=2, keepdim=True)
out = F.conv3d(x, w2d, b,
self.conv.stride, self.conv.padding,
self.conv.dilation, self.conv.groups)
return out, None
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
x = torch.cat(cached + [x], dim=2)
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
out = self.conv(x)
return out, conv_cache
def _interpolate_zq(zq, target_size):
"""Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
t = target_size[0]
if t > 1 and t % 2 == 1:
z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
return torch.cat([z_first, z_rest], dim=2)
return F.interpolate(zq, size=target_size)
class SpatialNorm3D(nn.Module):
"""Spatially conditioned normalization."""
def __init__(self, f_channels, zq_channels, groups=32):
super().__init__()
self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
def forward(self, f, zq, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
if zq.shape[-3:] != f.shape[-3:]:
zq = _interpolate_zq(zq, f.shape[-3:])
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
return self.norm_layer(f) * conv_y + conv_b, new_cache
class ResnetBlock3D(nn.Module):
"""3D ResNet block with optional spatial norm."""
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.spatial_norm_dim = spatial_norm_dim
if act_fn == "silu":
self.nonlinearity = nn.SiLU()
elif act_fn == "swish":
self.nonlinearity = nn.SiLU()
else:
self.nonlinearity = nn.SiLU()
if spatial_norm_dim is None:
self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
else:
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if temb_channels > 0:
self.temb_proj = ops.Linear(temb_channels, out_channels)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
if in_channels != out_channels:
self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
residual = x
if zq is not None:
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
else:
x = self.norm1(x)
x = self.nonlinearity(x)
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
if temb is not None and hasattr(self, "temb_proj"):
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if zq is not None:
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
else:
x = self.norm2(x)
x = self.nonlinearity(x)
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
if self.conv_shortcut is not None:
residual = self.conv_shortcut(residual)
return x + residual, new_cache
class Downsample3D(nn.Module):
"""3D downsampling with optional temporal compression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
super().__init__()
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
b, c, t, h, w = x.shape
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if t % 2 == 1:
x_first, x_rest = x[..., 0], x[..., 1:]
if x_rest.shape[-1] > 0:
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
x = torch.cat([x_first[..., None], x_rest], dim=-1)
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
else:
x = F.avg_pool1d(x, kernel_size=2, stride=2)
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.conv(x)
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
return x
class Upsample3D(nn.Module):
"""3D upsampling with optional temporal decompression."""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
super().__init__()
self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
self.compress_time = compress_time
def forward(self, x):
if self.compress_time:
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
x_first = F.interpolate(x_first, scale_factor=2.0)
x_rest = F.interpolate(x_rest, scale_factor=2.0)
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
elif x.shape[2] > 1:
x = F.interpolate(x, scale_factor=2.0)
else:
x = x.squeeze(2)
x = F.interpolate(x, scale_factor=2.0)
x = x[:, :, None, :, :]
else:
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = F.interpolate(x, scale_factor=2.0)
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
b, c, t, h, w = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
x = self.conv(x)
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
return x
class DownBlock3D(nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
compress_time=False, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
)
for i in range(num_layers)
])
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
if self.downsamplers is not None:
for ds in self.downsamplers:
x = ds(x)
return x, new_cache
class MidBlock3D(nn.Module):
def __init__(self, in_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels, out_channels=in_channels,
temb_channels=temb_channels, groups=groups, eps=eps,
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
)
for _ in range(num_layers)
])
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
return x, new_cache
class UpBlock3D(nn.Module):
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
add_upsample=True, compress_time=False, pad_mode="first"):
super().__init__()
self.resnets = nn.ModuleList([
ResnetBlock3D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
temb_channels=temb_channels, groups=groups, eps=eps,
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
)
for i in range(num_layers)
])
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
def forward(self, x, temb=None, zq=None, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
for i, resnet in enumerate(self.resnets):
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
if self.upsamplers is not None:
for us in self.upsamplers:
x = us(x)
return x, new_cache
class Encoder3D(nn.Module):
def __init__(self, in_channels=3, out_channels=16,
block_out_channels=(128, 256, 256, 512),
layers_per_block=3, act_fn="silu",
eps=1e-6, groups=32, pad_mode="first",
temporal_compression_ratio=4):
super().__init__()
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
self.down_blocks = nn.ModuleList()
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
self.down_blocks.append(DownBlock3D(
in_channels=input_channel, out_channels=output_channel,
temb_channels=0, num_layers=layers_per_block,
eps=eps, act_fn=act_fn, groups=groups,
add_downsample=not is_final, compress_time=compress_time,
))
self.mid_block = MidBlock3D(
in_channels=block_out_channels[-1], temb_channels=0,
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
)
self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
def forward(self, x, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
for i, block in enumerate(self.down_blocks):
key = f"down_block_{i}"
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
x = self.norm_out(x)
x = self.conv_act(x)
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
return x, new_cache
class Decoder3D(nn.Module):
def __init__(self, in_channels=16, out_channels=3,
block_out_channels=(128, 256, 256, 512),
layers_per_block=3, act_fn="silu",
eps=1e-6, groups=32, pad_mode="first",
temporal_compression_ratio=4):
super().__init__()
reversed_channels = list(reversed(block_out_channels))
temporal_compress_level = int(np.log2(temporal_compression_ratio))
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
self.mid_block = MidBlock3D(
in_channels=reversed_channels[0], temb_channels=0,
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
spatial_norm_dim=in_channels, pad_mode=pad_mode,
)
self.up_blocks = nn.ModuleList()
output_channel = reversed_channels[0]
for i in range(len(block_out_channels)):
prev_channel = output_channel
output_channel = reversed_channels[i]
is_final = i == len(block_out_channels) - 1
compress_time = i < temporal_compress_level
self.up_blocks.append(UpBlock3D(
in_channels=prev_channel, out_channels=output_channel,
temb_channels=0, num_layers=layers_per_block + 1,
eps=eps, act_fn=act_fn, groups=groups,
spatial_norm_dim=in_channels,
add_upsample=not is_final, compress_time=compress_time,
))
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
self.conv_act = nn.SiLU()
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
def forward(self, sample, conv_cache=None):
new_cache = {}
conv_cache = conv_cache or {}
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
for i, block in enumerate(self.up_blocks):
key = f"up_block_{i}"
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
x = self.conv_act(x)
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
return x, new_cache
class AutoencoderKLCogVideoX(nn.Module):
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
on the full (low-res) tensor, then the expensive spatial-only up_blocks +
norm_out + conv_out are processed in small temporal chunks with conv_cache
carrying causal state between chunks. This keeps peak VRAM proportional to
chunk_size rather than total frame count.
"""
def __init__(self,
in_channels=3, out_channels=3,
block_out_channels=(128, 256, 256, 512),
latent_channels=16, layers_per_block=3,
act_fn="silu", eps=1e-6, groups=32,
temporal_compression_ratio=4,
):
super().__init__()
self.latent_channels = latent_channels
self.temporal_compression_ratio = temporal_compression_ratio
self.encoder = Encoder3D(
in_channels=in_channels, out_channels=latent_channels,
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
act_fn=act_fn, eps=eps, groups=groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.decoder = Decoder3D(
in_channels=latent_channels, out_channels=out_channels,
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
act_fn=act_fn, eps=eps, groups=groups,
temporal_compression_ratio=temporal_compression_ratio,
)
self.num_latent_frames_batch_size = 2
self.num_sample_frames_batch_size = 8
def encode(self, x):
t = x.shape[2]
frame_batch = self.num_sample_frames_batch_size
remainder = t % frame_batch
conv_cache = None
enc = []
# Process remainder frames first so only the first chunk can have an
# odd temporal dimension — where Downsample3D's first-frame-special
# handling in temporal compression is actually correct.
if remainder > 0:
chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache)
enc.append(chunk.to(x.device))
for start in range(remainder, t, frame_batch):
chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache)
enc.append(chunk.to(x.device))
enc = torch.cat(enc, dim=2)
mean, _ = enc.chunk(2, dim=1)
return mean
def decode(self, z):
return self._decode_rolling(z)
def _decode_batched(self, z):
"""Original batched decode - processes 2 latent frames through full decoder."""
t = z.shape[2]
frame_batch = self.num_latent_frames_batch_size
num_batches = max(t // frame_batch, 1)
conv_cache = None
dec = []
for i in range(num_batches):
remaining = t % frame_batch
start = frame_batch * i + (0 if i == 0 else remaining)
end = frame_batch * (i + 1) + remaining
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
dec.append(chunk.cpu())
return torch.cat(dec, dim=2).to(z.device)
def _decode_rolling(self, z):
"""Rolling decode - processes low-res layers on full tensor, then rolls
through expensive high-res layers in temporal chunks."""
decoder = self.decoder
device = z.device
# Determine which up_blocks have temporal upsample vs spatial-only.
# Temporal up_blocks are cheap (low res), spatial-only are expensive.
temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
split_at = temporal_compress_level # first N up_blocks do temporal upsample
# Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
x, _ = decoder.conv_in(z)
x, _ = decoder.mid_block(x, None, z)
for i in range(split_at):
x, _ = decoder.up_blocks[i](x, None, z)
# Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
chunk_size = 4 # pixel frames per chunk through high-res layers
t_expanded = x.shape[2]
if t_expanded <= chunk_size or len(remaining_blocks) == 0:
# Small enough to process in one go
for i in remaining_blocks:
x, _ = decoder.up_blocks[i](x, None, z)
x, _ = decoder.norm_out(x, z)
x = decoder.conv_act(x)
x, _ = decoder.conv_out(x)
return x
# Expand z temporally once to match Phase 2's time dimension.
# z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB
# for the old approach of pre-interpolating to every pixel resolution).
z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4]))
# Process in temporal chunks, interpolating spatially per-chunk to avoid
# allocating full [B, C, t_expanded, H, W] tensors at each resolution.
dec_out = []
conv_caches = {}
for chunk_start in range(0, t_expanded, chunk_size):
chunk_end = min(chunk_start + chunk_size, t_expanded)
x_chunk = x[:, :, chunk_start:chunk_end]
z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end]
z_spatial_cache = {}
for i in remaining_blocks:
block = decoder.up_blocks[i]
cache_key = f"up_block_{i}"
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]:
z_spatial_cache[hw_key] = z_t_chunk
else:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key))
conv_caches[cache_key] = new_cache
hw_key = (x_chunk.shape[3], x_chunk.shape[4])
if hw_key not in z_spatial_cache:
z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1]))
x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out"))
conv_caches["norm_out"] = new_cache
x_chunk = decoder.conv_act(x_chunk)
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
conv_caches["conv_out"] = new_cache
dec_out.append(x_chunk.cpu())
del z_spatial_cache
del x, z_time_expanded
return torch.cat(dec_out, dim=2).to(device)

View File

@ -54,7 +54,7 @@ class SplitMHA(nn.Module):
if mask is not None and mask.ndim == 2: if mask is not None and mask.ndim == 2:
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
dtype = q.dtype # manual_cast may produce mixed dtypes dtype = q.dtype # manual_cast may produce mixed dtypes
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask) out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
return self.out_proj(out) return self.out_proj(out)

View File

@ -40,7 +40,7 @@ class SAMAttention(nn.Module):
q = self.q_proj(q) q = self.q_proj(q)
k = self.k_proj(k) k = self.k_proj(k)
v = self.v_proj(v) v = self.v_proj(v)
return self.out_proj(optimized_attention(q, k, v, self.num_heads)) return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
class TwoWayAttentionBlock(nn.Module): class TwoWayAttentionBlock(nn.Module):
@ -179,7 +179,7 @@ class Attention(nn.Module):
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
if self.use_rope and freqs_cis is not None: if self.use_rope and freqs_cis is not None:
q, k = apply_rope(q, k, freqs_cis) q, k = apply_rope(q, k, freqs_cis)
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True)) return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
class Block(nn.Module): class Block(nn.Module):

View File

@ -364,7 +364,7 @@ class SplitAttn(nn.Module):
v = self.v_proj(v) v = self.v_proj(v)
if rope is not None: if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
out = optimized_attention(q, k, v, self.num_heads) out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
return self.out_proj(out) return self.out_proj(out)
@ -657,7 +657,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
v = self.self_attn_v_proj(normed) v = self.self_attn_v_proj(normed)
if rope is not None: if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
# Decoupled cross-attention: fuse image and memory projections # Decoupled cross-attention: fuse image and memory projections
normed = self.norm2(x) normed = self.norm2(x)
@ -668,7 +668,7 @@ class DecoupledMemoryAttnLayer(nn.Module):
v = self.cross_attn_v_proj(memory) v = self.cross_attn_v_proj(memory)
if rope is not None: if rope is not None:
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads)) x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
# FFN # FFN
x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))

View File

@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15 import comfy.ldm.ace.ace_step15
import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector import comfy.ldm.sam3.detector
@ -81,6 +82,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9 IMG_TO_IMG = 9
FLOW_COSMOS = 10 FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11 IMG_TO_IMG_FLOW = 11
V_PREDICTION_DDPM = 12
def model_sampling(model_config, model_type): def model_sampling(model_config, model_type):
@ -115,6 +117,8 @@ def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingCosmosRFlow s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW: elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW c = comfy.model_sampling.IMG_TO_IMG_FLOW
elif model_type == ModelType.V_PREDICTION_DDPM:
c = comfy.model_sampling.V_PREDICTION_DDPM
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@ -1979,3 +1983,59 @@ class ErnieImage(BaseModel):
class SAM3(BaseModel): class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)
class CogVideoX(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
self.image_to_video = image_to_video
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape = list(noise.shape)
shape[1] = extra_channels
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if noise.ndim == 5 and image.ndim == 5:
if image.shape[-3] < noise.shape[-3]:
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
elif image.shape[-3] > noise.shape[-3]:
image = image[:, :, :noise.shape[-3]]
for i in range(0, image.shape[1], latent_dim):
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
if image.shape[1] > extra_channels:
image = image[:, :extra_channels]
elif image.shape[1] < extra_channels:
repeats = extra_channels // image.shape[1]
remainder = extra_channels % image.shape[1]
parts = [image] * repeats
if remainder > 0:
parts.append(image[:, :remainder])
image = torch.cat(parts, dim=1)
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
if self.diffusion_model.ofs_proj_dim is not None:
ofs = kwargs.get("ofs", None)
if ofs is None:
noise = kwargs.get("noise", None)
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
out['ofs'] = comfy.conds.CONDRegular(ofs)
return out

View File

@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
dit_config = {}
dit_config["image_model"] = "cogvideox"
# Extract config from weight shapes
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
time_embed_dim = norm1_weight.shape[1]
dim = norm1_weight.shape[0] // 6
dit_config["num_attention_heads"] = dim // 64
dit_config["attention_head_dim"] = 64
dit_config["time_embed_dim"] = time_embed_dim
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
# Detect in_channels from patch_embed
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
if patch_proj_key in state_dict_keys:
w = state_dict[patch_proj_key]
if w.ndim == 4:
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
dit_config["in_channels"] = w.shape[1]
dit_config["patch_size"] = w.shape[2]
elif w.ndim == 2:
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
dit_config["patch_size"] = 2
dit_config["patch_size_t"] = 2
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
if text_proj_key in state_dict_keys:
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
# Detect OFS embedding
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
if ofs_key in state_dict_keys:
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
# Detect positional embedding type
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
if pos_key in state_dict_keys:
dit_config["use_learned_positional_embeddings"] = True
dit_config["use_rotary_positional_embeddings"] = False
else:
dit_config["use_learned_positional_embeddings"] = False
dit_config["use_rotary_positional_embeddings"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "wan2.1" dit_config["image_model"] = "wan2.1"

View File

@ -663,6 +663,7 @@ def minimum_inference_memory():
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
comfy.memory_management.extra_ram_release(max(pins_required, ram_required))
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = [] unloaded_models = []

View File

@ -31,6 +31,7 @@ import comfy.float
import comfy.hooks import comfy.hooks
import comfy.lora import comfy.lora
import comfy.model_management import comfy.model_management
import comfy.ops
import comfy.patcher_extension import comfy.patcher_extension
import comfy.utils import comfy.utils
from comfy.comfy_types import UnetWrapperFunction from comfy.comfy_types import UnetWrapperFunction
@ -856,7 +857,9 @@ class ModelPatcher:
if m.comfy_patched_weights == True: if m.comfy_patched_weights == True:
continue continue
for param in params: for param, param_value in params.items():
if hasattr(m, "comfy_cast_weights") and getattr(param_value, "is_meta", False):
comfy.ops.disable_weight_init._zero_init_parameter(m, param)
key = key_param_name_to_key(n, param) key = key_param_name_to_key(n, param)
self.unpin_weight(key) self.unpin_weight(key)
self.patch_weight_to_device(key, device_to=device_to) self.patch_weight_to_device(key, device_to=device_to)

View File

@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
sigma = reshape_sigma(sigma, model_output.ndim) sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class V_PREDICTION_DDPM:
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
"""
def calculate_input(self, sigma, noise):
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma
noise += latent_image
return noise
def inverse_noise_scaling(self, sigma, latent):
return latent
class EDM(V_PREDICTION): class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input): def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim) sigma = reshape_sigma(sigma, model_output.ndim)

View File

@ -79,14 +79,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): def materialize_meta_param(s, param_keys):
for param_key in param_keys:
param = getattr(s, param_key, None)
if param is not None and getattr(param, "is_meta", False):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths #vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take #that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned #a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU #If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place. #or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device): if comfy.model_management.is_device_cpu(device):
materialize_meta_param(s, ["weight", "bias"])
weight = s.weight.to(dtype=dtype, copy=True) weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor): if isinstance(weight, QuantizedTensor):
weight = weight.dequantize() weight = weight.dequantize()
@ -108,6 +115,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident: if not resident:
materialize_meta_param(s, ["weight", "bias"])
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
@ -306,6 +314,12 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _zero_init_parameter(module, name):
param = getattr(module, name)
device = None if getattr(param, "is_meta", False) else param.device
setattr(module, name, torch.nn.Parameter(torch.zeros(param.shape, device=device, dtype=param.dtype), requires_grad=False))
@staticmethod @staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape, missing_keys, unexpected_keys, weight_shape,

View File

@ -2,7 +2,6 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import psutil
from comfy.cli_args import args from comfy.cli_args import args
@ -12,11 +11,6 @@ def get_pin(module):
def pin_memory(module): def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return return
#FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])

View File

@ -18,6 +18,7 @@ import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert import comfy.pixel_space_convert
@ -479,7 +480,10 @@ class VAE:
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config}, encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config}) decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd: elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1] if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
self.latent_channels = metadata["tae_latent_channels"]
else:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels) self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA() self.first_stage_model = StageA()
@ -653,6 +657,17 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd: elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True ddconfig["conv3d"] = True

View File

@ -27,6 +27,7 @@ import comfy.text_encoders.anima
import comfy.text_encoders.ace15 import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@ -1832,6 +1833,52 @@ class SAM31(SAM3):
unet_config = {"image_model": "SAM31"} unet_config = {"image_model": "SAM31"}
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] class CogVideoX_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "cogvideox",
}
sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.012,
"beta_schedule": "linear",
"zsnr": True,
}
unet_extra_config = {}
latent_format = latent_formats.CogVideoX
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
class CogVideoX_I2V(CogVideoX_T2V):
unet_config = {
"image_model": "cogvideox",
"in_channels": 32,
}
def get_model(self, state_dict, prefix="", device=None):
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31, CogVideoX_I2V, CogVideoX_T2V]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -7,6 +7,7 @@ from tqdm.auto import tqdm
from collections import namedtuple, deque from collections import namedtuple, deque
import comfy.ops import comfy.ops
import comfy.model_management
operations=comfy.ops.disable_weight_init operations=comfy.ops.disable_weight_init
DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
@ -47,11 +48,14 @@ class TGrow(nn.Module):
x = self.conv(x) x = self.conv(x)
return x.reshape(-1, C, H, W) return x.reshape(-1, C, H, W)
def apply_model_with_memblocks(model, x, parallel, show_progress_bar): def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
patch_size=1, decode=False):
B, T, C, H, W = x.shape B, T, C, H, W = x.shape
if parallel: if parallel:
x = x.reshape(B*T, C, H, W) x = x.reshape(B*T, C, H, W)
if not decode and patch_size > 1:
x = F.pixel_unshuffle(x, patch_size)
# parallel over input timesteps, iterate over blocks # parallel over input timesteps, iterate over blocks
for b in tqdm(model, disable=not show_progress_bar): for b in tqdm(model, disable=not show_progress_bar):
if isinstance(b, MemBlock): if isinstance(b, MemBlock):
@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
x = b(x, mem) x = b(x, mem)
else: else:
x = b(x) x = b(x)
BT, C, H, W = x.shape if decode and patch_size > 1:
T = BT // B x = F.pixel_shuffle(x, patch_size)
x = x.view(B, T, C, H, W) x = x.view(B, x.shape[0] // B, *x.shape[1:])
x = x.to(output_device)
else: else:
out = [] out = []
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
progress_bar = tqdm(range(T), disable=not show_progress_bar) progress_bar = tqdm(range(T), disable=not show_progress_bar)
mem = [None] * len(model) mem = [None] * len(model)
while work_queue: while work_queue:
xt, i = work_queue.popleft() xt, i = work_queue.popleft()
if i == 0: if i == 0:
progress_bar.update(1) progress_bar.update(1)
if not decode and patch_size > 1:
xt = F.pixel_unshuffle(xt, patch_size)
if i == len(model): if i == len(model):
out.append(xt) if decode and patch_size > 1:
xt = F.pixel_shuffle(xt, patch_size)
out.append(xt.to(output_device))
del xt del xt
else: else:
b = model[i] b = model[i]
@ -165,24 +176,20 @@ class TAEHV(nn.Module):
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
if self.patch_size > 1:
B, T, C, H, W = x.shape
x = x.reshape(B * T, C, H, W)
x = F.pixel_unshuffle(x, self.patch_size)
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
if x.shape[1] % self.t_downscale != 0: if x.shape[1] % self.t_downscale != 0:
# pad at end to multiple of t_downscale # pad at end to multiple of t_downscale
n_pad = self.t_downscale - x.shape[1] % self.t_downscale n_pad = self.t_downscale - x.shape[1] % self.t_downscale
padding = x[:, -1:].repeat_interleave(n_pad, dim=1) padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
x = torch.cat([x, padding], 1) x = torch.cat([x, padding], 1)
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
patch_size=self.patch_size).movedim(2, 1)
return self.process_out(x) return self.process_out(x)
def decode(self, x, **kwargs): def decode(self, x, **kwargs):
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W] x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W] x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
if self.patch_size > 1: output_device=comfy.model_management.intermediate_device(),
x = F.pixel_shuffle(x, self.patch_size) patch_size=self.patch_size, decode=True)
return x[:, self.frames_to_trim:].movedim(2, 1) return x[:, self.frames_to_trim:].movedim(2, 1)

View File

@ -17,32 +17,79 @@ class Clamp(nn.Module):
return torch.tanh(x / 3) * 3 return torch.tanh(x / 3) * 3
class Block(nn.Module): class Block(nn.Module):
def __init__(self, n_in, n_out): def __init__(self, n_in: int, n_out: int, use_midblock_gn: bool = False):
super().__init__() super().__init__()
self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out)) self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() self.skip = comfy.ops.disable_weight_init.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.fuse = nn.ReLU() self.fuse = nn.ReLU()
def forward(self, x): if not use_midblock_gn:
self.pool = None
return
n_gn = n_in * 4
self.pool = nn.Sequential(
comfy.ops.disable_weight_init.Conv2d(n_in, n_gn, 1, bias=False),
comfy.ops.disable_weight_init.GroupNorm(4, n_gn),
nn.ReLU(inplace=True),
comfy.ops.disable_weight_init.Conv2d(n_gn, n_in, 1, bias=False),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.pool is not None:
x = x + self.pool(x)
return self.fuse(self.conv(x) + self.skip(x)) return self.fuse(self.conv(x) + self.skip(x))
def Encoder(latent_channels=4): class Encoder(nn.Sequential):
return nn.Sequential( def __init__(self, latent_channels: int = 4, use_gn: bool = False):
conv(3, 64), Block(64, 64), super().__init__(
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(3, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64),
conv(64, latent_channels), conv(64, 64, stride=2, bias=False), Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn),
) conv(64, latent_channels),
)
class Decoder(nn.Sequential):
def __init__(self, latent_channels: int = 4, use_gn: bool = False):
super().__init__(
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64, use_gn), Block(64, 64, use_gn), Block(64, 64, use_gn), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class DecoderFlux2(Decoder):
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
if latent_channels != 128 or not use_gn:
raise ValueError("Unexpected parameters for Flux2 TAE module")
super().__init__(latent_channels=32, use_gn=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, C, H, W = x.shape
x = (
x
.reshape(B, 32, 2, 2, H, W)
.permute(0, 1, 4, 2, 5, 3)
.reshape(B, 32, H * 2, W * 2)
)
return super().forward(x)
class EncoderFlux2(Encoder):
def __init__(self, latent_channels: int = 128, use_gn: bool = True):
if latent_channels != 128 or not use_gn:
raise ValueError("Unexpected parameters for Flux2 TAE module")
super().__init__(latent_channels=32, use_gn=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
result = super().forward(x)
B, C, H, W = result.shape
return (
result
.reshape(B, C, H // 2, 2, W // 2, 2)
.permute(0, 1, 3, 5, 2, 4)
.reshape(B, 128, H // 2, W // 2)
)
def Decoder(latent_channels=4):
return nn.Sequential(
Clamp(), conv(latent_channels, 64), nn.ReLU(),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
Block(64, 64), conv(64, 3),
)
class TAESD(nn.Module): class TAESD(nn.Module):
latent_magnitude = 3 latent_magnitude = 3
@ -51,8 +98,15 @@ class TAESD(nn.Module):
def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4): def __init__(self, encoder_path=None, decoder_path=None, latent_channels=4):
"""Initialize pretrained TAESD on the given device from the given checkpoints.""" """Initialize pretrained TAESD on the given device from the given checkpoints."""
super().__init__() super().__init__()
self.taesd_encoder = Encoder(latent_channels=latent_channels) if latent_channels == 128:
self.taesd_decoder = Decoder(latent_channels=latent_channels) encoder_class = EncoderFlux2
decoder_class = DecoderFlux2
else:
encoder_class = Encoder
decoder_class = Decoder
self.taesd_encoder = encoder_class(latent_channels=latent_channels)
self.taesd_decoder = decoder_class(latent_channels=latent_channels)
self.vae_scale = torch.nn.Parameter(torch.tensor(1.0)) self.vae_scale = torch.nn.Parameter(torch.tensor(1.0))
self.vae_shift = torch.nn.Parameter(torch.tensor(0.0)) self.vae_shift = torch.nn.Parameter(torch.tensor(0.0))
if encoder_path is not None: if encoder_path is not None:
@ -61,19 +115,19 @@ class TAESD(nn.Module):
self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True)) self.taesd_decoder.load_state_dict(comfy.utils.load_torch_file(decoder_path, safe_load=True))
@staticmethod @staticmethod
def scale_latents(x): def scale_latents(x: torch.Tensor) -> torch.Tensor:
"""raw latents -> [0, 1]""" """raw latents -> [0, 1]"""
return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1) return x.div(2 * TAESD.latent_magnitude).add(TAESD.latent_shift).clamp(0, 1)
@staticmethod @staticmethod
def unscale_latents(x): def unscale_latents(x: torch.Tensor) -> torch.Tensor:
"""[0, 1] -> raw latents""" """[0, 1] -> raw latents"""
return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
def decode(self, x): def decode(self, x: torch.Tensor) -> torch.Tensor:
x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale) x_sample = self.taesd_decoder((x - self.vae_shift) * self.vae_scale)
x_sample = x_sample.sub(0.5).mul(2) x_sample = x_sample.sub(0.5).mul(2)
return x_sample return x_sample
def encode(self, x): def encode(self, x: torch.Tensor) -> torch.Tensor:
return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift return (self.taesd_encoder(x * 0.5 + 0.5) / self.vae_scale) + self.vae_shift

View File

@ -0,0 +1,6 @@
import comfy.text_encoders.sd3_clip
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)

View File

@ -9,6 +9,7 @@ from comfy_api.latest._input import (
CurveInput, CurveInput,
MonotoneCubicCurve, MonotoneCubicCurve,
LinearCurve, LinearCurve,
RangeInput,
) )
__all__ = [ __all__ = [
@ -21,4 +22,5 @@ __all__ = [
"CurveInput", "CurveInput",
"MonotoneCubicCurve", "MonotoneCubicCurve",
"LinearCurve", "LinearCurve",
"RangeInput",
] ]

View File

@ -1,5 +1,6 @@
from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput
from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve
from .range_types import RangeInput
from .video_types import VideoInput from .video_types import VideoInput
__all__ = [ __all__ = [
@ -12,4 +13,5 @@ __all__ = [
"CurveInput", "CurveInput",
"MonotoneCubicCurve", "MonotoneCubicCurve",
"LinearCurve", "LinearCurve",
"RangeInput",
] ]

View File

@ -0,0 +1,70 @@
from __future__ import annotations
import logging
import math
import numpy as np
logger = logging.getLogger(__name__)
class RangeInput:
"""Represents a levels/range adjustment: input range [min, max] with
optional midpoint (gamma control).
Generates a 1D LUT identical to GIMP's levels mapping:
1. Normalize input to [0, 1] using [min, max]
2. Apply gamma correction: pow(value, 1/gamma)
3. Clamp to [0, 1]
The midpoint field is a position in [0, 1] representing where the
midtone falls within [min, max]. It maps to gamma via:
gamma = -log2(midpoint)
So midpoint=0.5 gamma=1.0 (linear).
"""
def __init__(self, min_val: float, max_val: float, midpoint: float | None = None):
self.min_val = min_val
self.max_val = max_val
self.midpoint = midpoint
@staticmethod
def from_raw(data) -> RangeInput:
if isinstance(data, RangeInput):
return data
if isinstance(data, dict):
return RangeInput(
min_val=float(data.get("min", 0.0)),
max_val=float(data.get("max", 1.0)),
midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None,
)
raise TypeError(f"Cannot convert {type(data)} to RangeInput")
def to_lut(self, size: int = 256) -> np.ndarray:
"""Generate a float64 lookup table mapping [0, 1] input through this
levels adjustment.
The LUT maps normalized input values (0..1) to output values (0..1),
matching the GIMP levels formula.
"""
xs = np.linspace(0.0, 1.0, size, dtype=np.float64)
in_range = self.max_val - self.min_val
if abs(in_range) < 1e-10:
return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64)
# Normalize: map [min, max] → [0, 1]
result = (xs - self.min_val) / in_range
result = np.clip(result, 0.0, 1.0)
# Gamma correction from midpoint
if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5:
gamma = max(-math.log2(self.midpoint), 0.001)
inv_gamma = 1.0 / gamma
mask = result > 0
result[mask] = np.power(result[mask], inv_gamma)
return result
def __repr__(self) -> str:
mid = f", midpoint={self.midpoint}" if self.midpoint is not None else ""
return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})"

View File

@ -12,6 +12,7 @@ import numpy as np
import math import math
import torch import torch
from .._util import VideoContainer, VideoCodec, VideoComponents from .._util import VideoContainer, VideoCodec, VideoComponents
import logging
def container_to_output_format(container_format: str | None) -> str | None: def container_to_output_format(container_format: str | None) -> str | None:
@ -238,64 +239,125 @@ class VideoFromFile(VideoInput):
start_time = max(self._get_raw_duration() + self.__start_time, 0) start_time = max(self._get_raw_duration() + self.__start_time, 0)
else: else:
start_time = self.__start_time start_time = self.__start_time
# Get video frames # Get video frames
frames = [] frames = []
audio_frames = []
alphas = None
start_pts = int(start_time / video_stream.time_base) start_pts = int(start_time / video_stream.time_base)
end_pts = int((start_time + self.__duration) / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base)
container.seek(start_pts, stream=video_stream)
for frame in container.decode(video_stream):
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) if start_pts != 0:
container.seek(start_pts, stream=video_stream)
image_format = 'gbrpf32le'
process_image_format = lambda a: a
audio = None
streams = [video_stream]
has_first_audio_frame = False
checked_alpha = False
# Default to False so we decode until EOF if duration is 0
video_done = False
audio_done = True
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
streams += [audio_stream]
resampler = av.audio.resampler.AudioResampler(format='fltp')
audio_done = False
for packet in container.demux(*streams):
if video_done and audio_done:
break
if packet.stream.type == "video":
if video_done:
continue
try:
for frame in packet.decode():
if frame.pts < start_pts:
continue
if self.__duration and frame.pts >= end_pts:
video_done = True
break
if not checked_alpha:
alpha_channel = False
for comp in frame.format.components:
if comp.is_alpha or frame.format.name == "pal8":
alphas = []
alpha_channel = True
break
if frame.format.name in ("yuvj420p", "rgb24", "rgba", "pal8"):
process_image_format = lambda a: a.float() / 255.0
if alpha_channel:
image_format = 'rgba'
else:
image_format = 'rgb24'
else:
process_image_format = lambda a: a
if alpha_channel:
image_format = 'gbrapf32le'
else:
image_format = 'gbrpf32le'
checked_alpha = True
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
if frame.rotation != 0:
k = int(round(frame.rotation // 90))
img = np.rot90(img, k=k, axes=(0, 1)).copy()
if alphas is None:
frames.append(torch.from_numpy(img))
else:
frames.append(torch.from_numpy(img[..., :-1]))
alphas.append(torch.from_numpy(img[..., -1:]))
except av.error.InvalidDataError:
logging.info("pyav decode error")
elif packet.stream.type == "audio":
if audio_done:
continue
aframes = itertools.chain.from_iterable(
map(resampler.resample, packet.decode())
)
for frame in aframes:
if self.__duration and frame.time > start_time + self.__duration:
audio_done = True
break
if not has_first_audio_frame:
offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples:
has_first_audio_frame = True
audio_frames.append(frame.to_ndarray()[..., to_skip:])
else:
audio_frames.append(frame.to_ndarray())
images = process_image_format(torch.stack(frames)) if len(frames) > 0 else torch.zeros(0, 0, 0, 3)
if alphas is not None:
alphas = process_image_format(torch.stack(alphas)) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1)
# Get frame rate # Get frame rate
frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available if len(audio_frames) > 0:
audio = None audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
container.seek(start_pts, stream=video_stream) if self.__duration:
# Use last stream for consistency audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
if len(container.streams.audio):
audio_stream = container.streams.audio[-1]
audio_frames = []
resample = av.audio.resampler.AudioResampler(format='fltp').resample
frames = itertools.chain.from_iterable(
map(resample, container.decode(audio_stream))
)
has_first_frame = False audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
for frame in frames: audio = AudioInput({
offset_seconds = start_time - frame.pts * audio_stream.time_base "waveform": audio_tensor,
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate)) "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
if to_skip < frame.samples: })
has_first_frame = True
break
if has_first_frame:
audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames:
if self.__duration and frame.time > start_time + self.__duration:
break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0:
audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
if self.__duration:
audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
audio = AudioInput({
"waveform": audio_tensor,
"sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
})
metadata = container.metadata metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata)
def get_components(self) -> VideoComponents: def get_components(self) -> VideoComponents:
if isinstance(self.__file, io.BytesIO): if isinstance(self.__file, io.BytesIO):

View File

@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO):
Type = list[int] Type = list[int]
@comfytype(io_type="RANGE")
class Range(ComfyTypeIO):
from comfy_api.input import RangeInput
if TYPE_CHECKING:
Type = RangeInput
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None,
display: str=None,
gradient_stops: list=None,
show_midpoint: bool=None,
midpoint_scale: str=None,
value_min: float=None,
value_max: float=None,
advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = {"min": 0.0, "max": 1.0}
self.display = display
self.gradient_stops = gradient_stops
self.show_midpoint = show_midpoint
self.midpoint_scale = midpoint_scale
self.value_min = value_min
self.value_max = value_max
def as_dict(self):
return super().as_dict() | prune_dict({
"display": self.display,
"gradient_stops": self.gradient_stops,
"show_midpoint": self.show_midpoint,
"midpoint_scale": self.midpoint_scale,
"value_min": self.value_min,
"value_max": self.value_max,
})
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2276,5 +2313,6 @@ __all__ = [
"BoundingBox", "BoundingBox",
"Curve", "Curve",
"Histogram", "Histogram",
"Range",
"NodeReplace", "NodeReplace",
] ]

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from fractions import Fraction from fractions import Fraction
from typing import Optional from typing import Optional
from .._input import ImageInput, AudioInput from .._input import ImageInput, AudioInput, MaskInput
class VideoCodec(str, Enum): class VideoCodec(str, Enum):
AUTO = "auto" AUTO = "auto"
@ -48,5 +48,4 @@ class VideoComponents:
frame_rate: Fraction frame_rate: Fraction
audio: Optional[AudioInput] = None audio: Optional[AudioInput] = None
metadata: Optional[dict] = None metadata: Optional[dict] = None
alpha: Optional[MaskInput] = None

View File

@ -118,7 +118,7 @@ class Wan27ReferenceVideoInputField(BaseModel):
class Wan27ReferenceVideoParametersField(BaseModel): class Wan27ReferenceVideoParametersField(BaseModel):
resolution: str = Field(...) resolution: str = Field(...)
ratio: str | None = Field(None) ratio: str | None = Field(None)
duration: int = Field(5, ge=2, le=10) duration: int = Field(5, ge=2, le=15)
watermark: bool = Field(False) watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)
@ -157,7 +157,7 @@ class Wan27VideoEditInputField(BaseModel):
class Wan27VideoEditParametersField(BaseModel): class Wan27VideoEditParametersField(BaseModel):
resolution: str = Field(...) resolution: str = Field(...)
ratio: str | None = Field(None) ratio: str | None = Field(None)
duration: int = Field(0) duration: int | None = Field(0)
audio_setting: str = Field("auto") audio_setting: str = Field("auto")
watermark: bool = Field(False) watermark: bool = Field(False)
seed: int = Field(..., ge=0, le=2147483647) seed: int = Field(..., ge=0, le=2147483647)

View File

@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
), ),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode):
f"must equal the global duration ({duration}s)." f"must equal the global duration ({duration}s)."
) )
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
mode="pro" if resolution == "1080p" else "std", mode=mode,
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
shot_type="customize" if multi_shot else None, shot_type="customize" if multi_shot else None,
@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
optional=True, optional=True,
tooltip="Up to 6 additional reference images.", tooltip="Up to 6 additional reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
prompt=prompt, prompt=prompt,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
"reference_images", "reference_images",
tooltip="Up to 7 reference images.", tooltip="Up to 7 reference images.",
), ),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"storyboards", "storyboards",
options=[ options=[
@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr=""" expr="""
( (
$mode := (widgets.resolution = "720p") ? "std" : "pro"; $res := widgets.resolution;
$mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro");
$isV3 := $contains(widgets.model_name, "v3"); $isV3 := $contains(widgets.model_name, "v3");
$audio := $isV3 and widgets.generate_audio; $audio := $isV3 and widgets.generate_audio;
$rates := $audio $rates := $audio
? {"std": 0.112, "pro": 0.14} ? {"std": 0.112, "pro": 0.14, "4k": 0.42}
: {"std": 0.084, "pro": 0.112}; : {"std": 0.084, "pro": 0.112, "4k": 0.42};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
) )
""", """,
@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode):
raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
if generate_audio: if generate_audio:
raise ValueError("kling-video-o1 does not support audio generation.") raise ValueError("kling-video-o1 does not support audio generation.")
if resolution == "4k":
raise ValueError("kling-video-o1 does not support 4k resolution.")
stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
if stories_enabled and model_name == "kling-video-o1": if stories_enabled and model_name == "kling-video-o1":
raise ValueError("kling-video-o1 does not support storyboards.") raise ValueError("kling-video-o1 does not support storyboards.")
@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode):
image_list: list[OmniParamImage] = [] image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i)) image_list.append(OmniParamImage(image_url=i))
if resolution == "4k":
mode = "4k"
elif resolution == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
duration=str(duration), duration=str(duration),
image_list=image_list, image_list=image_list,
mode="pro" if resolution == "1080p" else "std", mode=mode,
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
multi_shot=multi_shot, multi_shot=multi_shot,
multi_prompt=multi_prompt_list, multi_prompt=multi_prompt_list,
@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
IO.Combo.Input( IO.Combo.Input(
"aspect_ratio", "aspect_ratio",
options=["16:9", "9:16", "1:1"], options=["16:9", "9:16", "1:1"],
@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode):
start_frame: Input.Image | None = None, start_frame: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
_ = seed _ = seed
mode = "pro" if model["resolution"] == "1080p" else "std" if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
custom_multi_shot = False custom_multi_shot = False
if multi_shot["multi_shot"] == "disabled": if multi_shot["multi_shot"] == "disabled":
shot_type = None shot_type = None
@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path=poll_path), ApiEndpoint(path=poll_path),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"kling-v3", "kling-v3",
[ [
IO.Combo.Input("resolution", options=["1080p", "720p"]), IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"),
], ],
), ),
], ],
@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
), ),
expr=""" expr="""
( (
$rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; $rates := {
"4k": {"off": 0.42, "on": 0.42},
"1080p": {"off": 0.112, "on": 0.168},
"720p": {"off": 0.084, "on": 0.126}
};
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$audio := widgets.generate_audio ? "on" : "off"; $audio := widgets.generate_audio ? "on" : "off";
$rate := $lookup($lookup($rates, $res), $audio); $rate := $lookup($lookup($rates, $res), $audio);
@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
if model["resolution"] == "4k":
mode = "4k"
elif model["resolution"] == "1080p":
mode = "pro"
else:
mode = "std"
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
image=image_url, image=image_url,
image_tail=image_tail_url, image_tail=image_tail_url,
prompt=prompt, prompt=prompt,
mode="pro" if model["resolution"] == "1080p" else "std", mode=mode,
duration=str(duration), duration=str(duration),
sound="on" if generate_audio else "off", sound="on" if generate_audio else "off",
), ),
@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
max_poll_attempts=280,
status_extractor=lambda r: (r.data.task_status if r.data else None), status_extractor=lambda r: (r.data.task_status if r.data else None),
) )
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))

View File

@ -415,8 +415,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
"1152x2048", "1152x2048",
"3840x2160", "3840x2160",
"2160x3840", "2160x3840",
"Custom",
], ],
tooltip="Image size", tooltip="Image size. Select 'Custom' to use the custom width and height (GPT Image 2 only).",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -445,6 +446,26 @@ class OpenAIGPTImage1(IO.ComfyNode):
default="gpt-image-2", default="gpt-image-2",
optional=True, optional=True,
), ),
IO.Int.Input(
"custom_width",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
optional=True,
advanced=True,
),
IO.Int.Input(
"custom_height",
default=1024,
min=1024,
max=3840,
step=16,
tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16 (GPT Image 2 only).",
optional=True,
advanced=True,
),
], ],
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
@ -471,9 +492,9 @@ class OpenAIGPTImage1(IO.ComfyNode):
"high": [0.133, 0.22] "high": [0.133, 0.22]
}, },
"gpt-image-2": { "gpt-image-2": {
"low": [0.0048, 0.012], "low": [0.0048, 0.019],
"medium": [0.041, 0.112], "medium": [0.041, 0.168],
"high": [0.165, 0.43] "high": [0.165, 0.67]
} }
}; };
$range := $lookup($lookup($ranges, widgets.model), widgets.quality); $range := $lookup($lookup($ranges, widgets.model), widgets.quality);
@ -503,6 +524,8 @@ class OpenAIGPTImage1(IO.ComfyNode):
mask: Input.Image | None = None, mask: Input.Image | None = None,
n: int = 1, n: int = 1,
size: str = "1024x1024", size: str = "1024x1024",
custom_width: int = 1024,
custom_height: int = 1024,
model: str = "gpt-image-1", model: str = "gpt-image-1",
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False) validate_string(prompt, strip_whitespace=False)
@ -510,7 +533,25 @@ class OpenAIGPTImage1(IO.ComfyNode):
if mask is not None and image is None: if mask is not None and image is None:
raise ValueError("Cannot use a mask without an input image") raise ValueError("Cannot use a mask without an input image")
if model in ("gpt-image-1", "gpt-image-1.5"): if size == "Custom":
if model != "gpt-image-2":
raise ValueError("Custom resolution is only supported by GPT Image 2 model")
if custom_width % 16 != 0 or custom_height % 16 != 0:
raise ValueError(f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}")
if max(custom_width, custom_height) > 3840:
raise ValueError(f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}")
ratio = max(custom_width, custom_height) / min(custom_width, custom_height)
if ratio > 3:
raise ValueError(
f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
)
total_pixels = custom_width * custom_height
if not 655_360 <= total_pixels <= 8_294_400:
raise ValueError(
f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
)
size = f"{custom_width}x{custom_height}"
elif model in ("gpt-image-1", "gpt-image-1.5"):
if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"): if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"):
raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model") raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model")

View File

@ -33,9 +33,13 @@ class OpenAIVideoSora2(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="OpenAIVideoSora2", node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video", display_name="OpenAI Sora - Video (Deprecated)",
category="api node/video/Sora", category="api node/video/Sora",
description="OpenAI video and audio generation.", description=(
"OpenAI video and audio generation.\n\n"
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "
"This node will be removed from ComfyUI at that time."
),
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
"model", "model",

View File

@ -1646,6 +1646,557 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class HappyHorseTextToVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="HappyHorseTextToVideoApi",
display_name="HappyHorse Text to Video",
category="api node/video/Wan",
description="Generates a video based on a text prompt using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.0-t2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4"],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
seed: int,
watermark: bool,
):
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
method="POST",
),
response_model=TaskCreationResponse,
data=Wan27Text2VideoTaskCreationRequest(
model=model["model"],
input=Text2VideoInputField(
prompt=model["prompt"],
negative_prompt=None,
),
parameters=Wan27Text2VideoParametersField(
resolution=model["resolution"],
ratio=model["ratio"],
duration=model["duration"],
seed=seed,
watermark=watermark,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=7,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class HappyHorseImageToVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="HappyHorseImageToVideoApi",
display_name="HappyHorse Image to Video",
category="api node/video/Wan",
description="Generate a video from a first-frame image using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.0-i2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. "
"Supports English and Chinese.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
],
),
],
),
IO.Image.Input(
"first_frame",
tooltip="First frame image. The output aspect ratio is derived from this image.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
first_frame: Input.Image,
seed: int,
watermark: bool,
):
media = [
Wan27MediaItem(
type="first_frame",
url=await upload_image_to_comfyapi(cls, image=first_frame),
)
]
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
method="POST",
),
response_model=TaskCreationResponse,
data=Wan27ImageToVideoTaskCreationRequest(
model=model["model"],
input=Wan27ImageToVideoInputField(
prompt=model["prompt"] or None,
negative_prompt=None,
media=media,
),
parameters=Wan27ImageToVideoParametersField(
resolution=model["resolution"],
duration=model["duration"],
seed=seed,
watermark=watermark,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=7,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class HappyHorseVideoEditApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="HappyHorseVideoEditApi",
display_name="HappyHorse Video Edit",
category="api node/video/Wan",
description="Edit a video using text instructions or reference images with the HappyHorse model. "
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.0-video-edit",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Editing instructions or style transfer requirements.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4"],
tooltip="Aspect ratio. If not changed, approximates the input video ratio.",
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_image"),
names=[
"image1",
"image2",
"image3",
"image4",
"image5",
],
min=0,
),
),
],
),
],
),
IO.Video.Input(
"video",
tooltip="The video to edit.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps, "format": { "suffix": "/second" } }
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
video: Input.Video,
seed: int,
watermark: bool,
):
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
validate_video_duration(video, min_duration=3, max_duration=60)
media = [Wan27MediaItem(type="video", url=await upload_video_to_comfyapi(cls, video))]
reference_images = model.get("reference_images", {})
for key in reference_images:
media.append(
Wan27MediaItem(
type="reference_image", url=await upload_image_to_comfyapi(cls, image=reference_images[key])
)
)
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
method="POST",
),
response_model=TaskCreationResponse,
data=Wan27VideoEditTaskCreationRequest(
model=model["model"],
input=Wan27VideoEditInputField(prompt=model["prompt"], media=media),
parameters=Wan27VideoEditParametersField(
resolution=model["resolution"],
ratio=model["ratio"],
duration=None,
watermark=watermark,
seed=seed,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=7,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class HappyHorseReferenceVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="HappyHorseReferenceVideoApi",
display_name="HappyHorse Reference to Video",
category="api node/video/Wan",
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
"model. Supports single-character performances and multi-character interactions.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"happyhorse-1.0-r2v",
[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the video. Use identifiers such as 'character1' and "
"'character2' to refer to the reference characters.",
),
IO.Combo.Input(
"resolution",
options=["720P", "1080P"],
),
IO.Combo.Input(
"ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4"],
),
IO.Int.Input(
"duration",
default=5,
min=3,
max=15,
step=1,
display_mode=IO.NumberDisplay.number,
),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("reference_image"),
names=[
"image1",
"image2",
"image3",
"image4",
"image5",
"image6",
"image7",
"image8",
"image9",
],
min=1,
),
),
],
),
],
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed to use for generation.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
advanced=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution", "model.duration"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$dur := $lookup(widgets, "model.duration");
$ppsTable := { "720p": 0.14, "1080p": 0.24 };
$pps := $lookup($ppsTable, $res);
{ "type": "usd", "usd": $pps * $dur }
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
seed: int,
watermark: bool,
):
validate_string(model["prompt"], strip_whitespace=False, min_length=1)
media = []
reference_images = model.get("reference_images", {})
for key in reference_images:
media.append(
Wan27MediaItem(
type="reference_image",
url=await upload_image_to_comfyapi(cls, image=reference_images[key]),
)
)
if not media:
raise ValueError("At least one reference reference image must be provided.")
initial_response = await sync_op(
cls,
ApiEndpoint(
path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
method="POST",
),
response_model=TaskCreationResponse,
data=Wan27ReferenceVideoTaskCreationRequest(
model=model["model"],
input=Wan27ReferenceVideoInputField(
prompt=model["prompt"],
negative_prompt=None,
media=media,
),
parameters=Wan27ReferenceVideoParametersField(
resolution=model["resolution"],
ratio=model["ratio"],
duration=model["duration"],
watermark=watermark,
seed=seed,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=7,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
class WanApiExtension(ComfyExtension): class WanApiExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1660,6 +2211,10 @@ class WanApiExtension(ComfyExtension):
Wan2VideoContinuationApi, Wan2VideoContinuationApi,
Wan2VideoEditApi, Wan2VideoEditApi,
Wan2ReferenceVideoApi, Wan2ReferenceVideoApi,
HappyHorseTextToVideoApi,
HappyHorseImageToVideoApi,
HappyHorseVideoEditApi,
HappyHorseReferenceVideoApi,
] ]

View File

@ -5,6 +5,7 @@ import psutil
import time import time
import torch import torch
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy.model_patcher import ModelPatcher
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -523,13 +524,15 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set_local(node_id, value) super().set_local(node_id, value)
def ram_release(self, target): def ram_release(self, target, free_active=False):
if psutil.virtual_memory().available >= target: if psutil.virtual_memory().available >= target:
return return
clean_list = [] clean_list = []
for key, cache_entry in self.cache.items(): for key, cache_entry in self.cache.items():
if not free_active and self.used_generation[key] == self.generation:
continue
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -542,6 +545,9 @@ class RAMPressureCache(LRUCache):
scan_list_for_ram_usage(output) scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
ram_usage += output.numel() * output.element_size() ram_usage += output.numel() * output.element_size()
elif isinstance(output, ModelPatcher) and self.used_generation[key] != self.generation:
#old ModelPatchers are the first to go
ram_usage = 1e30
scan_list_for_ram_usage(cache_entry.outputs) scan_list_for_ram_usage(cache_entry.outputs)
oom_score *= ram_usage oom_score *= ram_usage

View File

@ -637,7 +637,7 @@ class SaveGLB(IO.ComfyNode):
], ],
tooltip="Mesh or 3D file to save", tooltip="Mesh or 3D file to save",
), ),
IO.String.Input("filename_prefix", default="mesh/ComfyUI"), IO.String.Input("filename_prefix", default="3d/ComfyUI"),
], ],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
) )

View File

@ -1,6 +1,7 @@
import nodes import nodes
import node_helpers import node_helpers
import torch import torch
import torchaudio
import comfy.model_management import comfy.model_management
import comfy.model_sampling import comfy.model_sampling
import comfy.samplers import comfy.samplers
@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode):
@classmethod @classmethod
def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
# Encode reference audio to latents and patchify # Encode reference audio to latents and patchify
audio_latents = audio_vae.encode(reference_audio) sample_rate = reference_audio["sample_rate"]
vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100)
if vae_sample_rate != sample_rate:
waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate)
else:
waveform = reference_audio["waveform"]
audio_latents = audio_vae.encode(waveform.movedim(1, -1))
b, c, t, f = audio_latents.shape b, c, t, f = audio_latents.shape
ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f)
ref_audio = {"tokens": ref_tokens} ref_audio = {"tokens": ref_tokens}

View File

@ -2,6 +2,7 @@ import numpy as np
import scipy.ndimage import scipy.ndimage
import torch import torch
import comfy.utils import comfy.utils
import comfy.model_management
import node_helpers import node_helpers
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, UI from comfy_api.latest import ComfyExtension, IO, UI
@ -188,7 +189,7 @@ class SolidMask(IO.ComfyNode):
@classmethod @classmethod
def execute(cls, value, width, height) -> IO.NodeOutput: def execute(cls, value, width, height) -> IO.NodeOutput:
out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") out = torch.full((1, height, width), value, dtype=torch.float32, device=comfy.model_management.intermediate_device())
return IO.NodeOutput(out) return IO.NodeOutput(out)
solid = execute # TODO: remove solid = execute # TODO: remove
@ -262,6 +263,7 @@ class MaskComposite(IO.ComfyNode):
def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput:
output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
source = source.reshape((-1, source.shape[-2], source.shape[-1])) source = source.reshape((-1, source.shape[-2], source.shape[-1]))
source = source.to(output.device)
left, top = (x, y,) left, top = (x, y,)
right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2]))

View File

@ -1,5 +1,6 @@
import json import json
from comfy.comfy_types.node_typing import IO from comfy.comfy_types.node_typing import IO
import torch
# Preview Any - original implement from # Preview Any - original implement from
# https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py # https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py
@ -19,6 +20,7 @@ class PreviewAny():
SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"]
def main(self, source=None): def main(self, source=None):
torch.set_printoptions(edgeitems=6)
value = 'None' value = 'None'
if isinstance(source, str): if isinstance(source, str):
value = source value = source
@ -33,6 +35,7 @@ class PreviewAny():
except Exception: except Exception:
value = 'source exists, but could not be serialized.' value = 'source exists, but could not be serialized.'
torch.set_printoptions()
return {"ui": {"text": (value,)}, "result": (value,)} return {"ui": {"text": (value,)}, "result": (value,)}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {

View File

@ -54,7 +54,7 @@ class EmptySD3LatentImage(io.ComfyNode):
@classmethod @classmethod
def execute(cls, width, height, batch_size=1) -> io.NodeOutput: def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device()) latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8}) return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
generate = execute # TODO: remove generate = execute # TODO: remove

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.19.3" __version__ = "0.20.1"

View File

@ -779,7 +779,7 @@ class PromptExecutor:
if self.cache_type == CacheType.RAM_PRESSURE: if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
comfy.memory_management.extra_ram_release(ram_headroom) ram_release_callback(ram_headroom, free_active=True)
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed # Send cached UI for intermediate output nodes that weren't executed
@ -811,11 +811,30 @@ class PromptExecutor:
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
if visiting is None:
visiting = []
unique_id = item unique_id = item
if unique_id in validated: if unique_id in validated:
return validated[unique_id] return validated[unique_id]
if unique_id in visiting:
cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id]
cycle_nodes = list(dict.fromkeys(cycle_path_nodes))
cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes)
for node_id in cycle_nodes:
validated[node_id] = (False, [{
"type": "dependency_cycle",
"message": "Dependency cycle detected",
"details": cycle_path,
"extra_info": {
"node_id": node_id,
"cycle_nodes": cycle_nodes,
}
}], node_id)
return validated[unique_id]
inputs = prompt[unique_id]['inputs'] inputs = prompt[unique_id]['inputs']
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
try: try:
r = await validate_inputs(prompt_id, prompt, o_id, validated) visiting.append(unique_id)
try:
r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting)
finally:
visiting.pop()
if r[0] is False: if r[0] is False:
# `r` will be set in `validated[o_id]` already # `r` will be set in `validated[o_id]` already
valid = False valid = False
@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if len(errors) > 0 or valid is not True: ret = validated.get(unique_id, (True, [], unique_id))
ret = (False, errors, unique_id) # Recursive cycle detection may have already populated an error on us. Join it.
else: ret = (
ret = (True, [], unique_id) ret[0] and valid is True and not errors,
ret[1] + [error for error in errors if error not in ret[1]],
unique_id,
)
validated[unique_id] = ret validated[unique_id] = ret
return ret return ret

View File

@ -32,7 +32,7 @@ import comfy.controlnet
from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator from comfy.comfy_types import IO, ComfyNodeABC, InputTypeDict, FileLocator
from comfy_api.internal import register_versions, ComfyAPIWithVersion from comfy_api.internal import register_versions, ComfyAPIWithVersion
from comfy_api.version_list import supported_versions from comfy_api.version_list import supported_versions
from comfy_api.latest import io, ComfyExtension from comfy_api.latest import io, ComfyExtension, InputImpl
import comfy.clip_vision import comfy.clip_vision
@ -728,50 +728,26 @@ class LoraLoaderModelOnly(LoraLoader):
class VAELoader: class VAELoader:
video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"] video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5", "taeltx_2"]
image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] image_taes = ["taesd", "taesdxl", "taesd3", "taef1", "taef2"]
@staticmethod @staticmethod
def vae_list(s): def vae_list(s):
vaes = folder_paths.get_filename_list("vae") vaes = folder_paths.get_filename_list("vae")
approx_vaes = folder_paths.get_filename_list("vae_approx") approx_vaes = folder_paths.get_filename_list("vae_approx")
sdxl_taesd_enc = False have_img_encoder, have_img_decoder = set(), set()
sdxl_taesd_dec = False
sd1_taesd_enc = False
sd1_taesd_dec = False
sd3_taesd_enc = False
sd3_taesd_dec = False
f1_taesd_enc = False
f1_taesd_dec = False
for v in approx_vaes: for v in approx_vaes:
if v.startswith("taesd_decoder."): parts = v.split("_", 1)
sd1_taesd_dec = True if len(parts) != 2 or parts[0] not in s.image_taes:
elif v.startswith("taesd_encoder."):
sd1_taesd_enc = True
elif v.startswith("taesdxl_decoder."):
sdxl_taesd_dec = True
elif v.startswith("taesdxl_encoder."):
sdxl_taesd_enc = True
elif v.startswith("taesd3_decoder."):
sd3_taesd_dec = True
elif v.startswith("taesd3_encoder."):
sd3_taesd_enc = True
elif v.startswith("taef1_encoder."):
f1_taesd_dec = True
elif v.startswith("taef1_decoder."):
f1_taesd_enc = True
else:
for tae in s.video_taes: for tae in s.video_taes:
if v.startswith(tae): if v.startswith(tae):
vaes.append(v) vaes.append(v)
break
if sd1_taesd_dec and sd1_taesd_enc: continue
vaes.append("taesd") if parts[1].startswith("encoder."):
if sdxl_taesd_dec and sdxl_taesd_enc: have_img_encoder.add(parts[0])
vaes.append("taesdxl") elif parts[1].startswith("decoder."):
if sd3_taesd_dec and sd3_taesd_enc: have_img_decoder.add(parts[0])
vaes.append("taesd3") vaes += [k for k in have_img_decoder if k in have_img_encoder]
if f1_taesd_dec and f1_taesd_enc:
vaes.append("taef1")
vaes.append("pixel_space") vaes.append("pixel_space")
return vaes return vaes
@ -827,6 +803,11 @@ class VAELoader:
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(vae_path, return_metadata=True)
if vae_name == "taef2":
if metadata is None:
metadata = {"tae_latent_channels": 128}
else:
metadata["tae_latent_channels"] = 128
vae = comfy.sd.VAE(sd=sd, metadata=metadata) vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid() vae.throw_exception_if_invalid()
return (vae,) return (vae,)
@ -1716,6 +1697,10 @@ class LoadImage:
def load_image(self, image): def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image) image_path = folder_paths.get_annotated_filepath(image)
components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
img = node_helpers.pillow(Image.open, image_path) img = node_helpers.pillow(Image.open, image_path)
output_images = [] output_images = []
@ -2459,7 +2444,7 @@ async def init_builtin_extra_nodes():
"nodes_curve.py", "nodes_curve.py",
"nodes_rtdetr.py", "nodes_rtdetr.py",
"nodes_frame_interpolation.py", "nodes_frame_interpolation.py",
"nodes_sam3.py" "nodes_sam3.py",
] ]
import_failed = [] import_failed = []

3231
openapi.yaml Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.19.3" version = "0.20.1"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.42.14 comfyui-frontend-package==1.42.15
comfyui-workflow-templates==0.9.59 comfyui-workflow-templates==0.9.63
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.4
torch torch
torchsde torchsde
torchvision torchvision
@ -19,11 +19,11 @@ scipy
tqdm tqdm
psutil psutil
alembic alembic
SQLAlchemy>=2.0 SQLAlchemy>=2.0.0
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.8 comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.12 comfy-aimdo==0.3.0
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3