diff --git a/.gitignore b/.gitignore index 2700ad5c2..0ab4ba75e 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,5 @@ venv*/ *.log web_custom_versions/ .DS_Store -openapi.yaml filtered-openapi.yaml uv.lock diff --git a/blueprints/.glsl/Glow_30.frag b/blueprints/.glsl/Glow_30.frag index 0ee152628..f3c85a212 100644 --- a/blueprints/.glsl/Glow_30.frag +++ b/blueprints/.glsl/Glow_30.frag @@ -2,7 +2,6 @@ precision mediump float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform int u_int0; // Blend mode uniform int u_int1; // Color tint uniform float u_float0; // Intensity @@ -75,7 +74,7 @@ void main() { float t0 = threshold - 0.15; float t1 = threshold + 0.15; - vec2 texelSize = 1.0 / u_resolution; + vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0)); float radius2 = radius * radius; float sampleScale = clamp(radius * 0.75, 0.35, 1.0); diff --git a/blueprints/.glsl/Image_Blur_1.frag b/blueprints/.glsl/Image_Blur_1.frag index 83238111d..1819e1695 100644 --- a/blueprints/.glsl/Image_Blur_1.frag +++ b/blueprints/.glsl/Image_Blur_1.frag @@ -12,7 +12,6 @@ const int RADIAL_SAMPLES = 12; const float RADIAL_STRENGTH = 0.0003; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL) uniform float u_float0; // Blur radius/amount uniform int u_pass; // Pass index (0 = horizontal, 1 = vertical) @@ -25,7 +24,7 @@ float gaussian(float x, float sigma) { } void main() { - vec2 texelSize = 1.0 / u_resolution; + vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0)); float radius = max(u_float0, 0.0); // Radial (angular) blur - single pass, doesn't use separable diff --git a/blueprints/.glsl/Sharpen_23.frag b/blueprints/.glsl/Sharpen_23.frag index c03f94b66..e7463a329 100644 --- a/blueprints/.glsl/Sharpen_23.frag +++ b/blueprints/.glsl/Sharpen_23.frag @@ -2,14 +2,13 @@ precision highp float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0 in vec2 v_texCoord; layout(location = 0) out vec4 fragColor0; void main() { - vec2 texel = 1.0 / u_resolution; + vec2 texel = 1.0 / vec2(textureSize(u_image0, 0)); // Sample center and neighbors vec4 center = texture(u_image0, v_texCoord); diff --git a/blueprints/.glsl/Unsharp_Mask_26.frag b/blueprints/.glsl/Unsharp_Mask_26.frag index f5990cb4a..d968c9c03 100644 --- a/blueprints/.glsl/Unsharp_Mask_26.frag +++ b/blueprints/.glsl/Unsharp_Mask_26.frag @@ -2,7 +2,6 @@ precision highp float; uniform sampler2D u_image0; -uniform vec2 u_resolution; uniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5 uniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels uniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen @@ -19,7 +18,7 @@ float getLuminance(vec3 color) { } void main() { - vec2 texel = 1.0 / u_resolution; + vec2 texel = 1.0 / vec2(textureSize(u_image0, 0)); float radius = max(u_float1, 0.5); float amount = u_float0; float threshold = u_float2; diff --git a/blueprints/Glow.json b/blueprints/Glow.json index 8c690fc68..1dafb2d35 100644 --- a/blueprints/Glow.json +++ b/blueprints/Glow.json @@ -268,7 +268,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / u_resolution;\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", + "#version 300 es\nprecision mediump float;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blend mode\nuniform int u_int1; // Color tint\nuniform float u_float0; // Intensity\nuniform float u_float1; // Radius\nuniform float u_float2; // Threshold\n\nin vec2 v_texCoord;\nout vec4 fragColor;\n\nconst int BLEND_ADD = 0;\nconst int BLEND_SCREEN = 1;\nconst int BLEND_SOFT = 2;\nconst int BLEND_OVERLAY = 3;\nconst int BLEND_LIGHTEN = 4;\n\nconst float GOLDEN_ANGLE = 2.39996323;\nconst int MAX_SAMPLES = 48;\nconst vec3 LUMA = vec3(0.299, 0.587, 0.114);\n\nfloat hash(vec2 p) {\n p = fract(p * vec2(123.34, 456.21));\n p += dot(p, p + 45.32);\n return fract(p.x * p.y);\n}\n\nvec3 hexToRgb(int h) {\n return vec3(\n float((h >> 16) & 255),\n float((h >> 8) & 255),\n float(h & 255)\n ) * (1.0 / 255.0);\n}\n\nvec3 blend(vec3 base, vec3 glow, int mode) {\n if (mode == BLEND_SCREEN) {\n return 1.0 - (1.0 - base) * (1.0 - glow);\n }\n if (mode == BLEND_SOFT) {\n return mix(\n base - (1.0 - 2.0 * glow) * base * (1.0 - base),\n base + (2.0 * glow - 1.0) * (sqrt(base) - base),\n step(0.5, glow)\n );\n }\n if (mode == BLEND_OVERLAY) {\n return mix(\n 2.0 * base * glow,\n 1.0 - 2.0 * (1.0 - base) * (1.0 - glow),\n step(0.5, base)\n );\n }\n if (mode == BLEND_LIGHTEN) {\n return max(base, glow);\n }\n return base + glow;\n}\n\nvoid main() {\n vec4 original = texture(u_image0, v_texCoord);\n \n float intensity = u_float0 * 0.05;\n float radius = u_float1 * u_float1 * 0.012;\n \n if (intensity < 0.001 || radius < 0.1) {\n fragColor = original;\n return;\n }\n \n float threshold = 1.0 - u_float2 * 0.01;\n float t0 = threshold - 0.15;\n float t1 = threshold + 0.15;\n \n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius2 = radius * radius;\n \n float sampleScale = clamp(radius * 0.75, 0.35, 1.0);\n int samples = int(float(MAX_SAMPLES) * sampleScale);\n \n float noise = hash(gl_FragCoord.xy);\n float angleOffset = noise * GOLDEN_ANGLE;\n float radiusJitter = 0.85 + noise * 0.3;\n \n float ca = cos(GOLDEN_ANGLE);\n float sa = sin(GOLDEN_ANGLE);\n vec2 dir = vec2(cos(angleOffset), sin(angleOffset));\n \n vec3 glow = vec3(0.0);\n float totalWeight = 0.0;\n \n // Center tap\n float centerMask = smoothstep(t0, t1, dot(original.rgb, LUMA));\n glow += original.rgb * centerMask * 2.0;\n totalWeight += 2.0;\n \n for (int i = 1; i < MAX_SAMPLES; i++) {\n if (i >= samples) break;\n \n float fi = float(i);\n float dist = sqrt(fi / float(samples)) * radius * radiusJitter;\n \n vec2 offset = dir * dist * texelSize;\n vec3 c = texture(u_image0, v_texCoord + offset).rgb;\n float mask = smoothstep(t0, t1, dot(c, LUMA));\n \n float w = 1.0 - (dist * dist) / (radius2 * 1.5);\n w = max(w, 0.0);\n w *= w;\n \n glow += c * mask * w;\n totalWeight += w;\n \n dir = vec2(\n dir.x * ca - dir.y * sa,\n dir.x * sa + dir.y * ca\n );\n }\n \n glow *= intensity / max(totalWeight, 0.001);\n \n if (u_int1 > 0) {\n glow *= hexToRgb(u_int1);\n }\n \n vec3 result = blend(original.rgb, glow, u_int0);\n result += (noise - 0.5) * (1.0 / 255.0);\n \n fragColor = vec4(clamp(result, 0.0, 1.0), original.a);\n}", "from_input" ] }, diff --git a/blueprints/Image Blur.json b/blueprints/Image Blur.json index b1d449e32..3c7a784b0 100644 --- a/blueprints/Image Blur.json +++ b/blueprints/Image Blur.json @@ -331,7 +331,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / u_resolution;\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", + "#version 300 es\n#pragma passes 2\nprecision highp float;\n\n// Blur type constants\nconst int BLUR_GAUSSIAN = 0;\nconst int BLUR_BOX = 1;\nconst int BLUR_RADIAL = 2;\n\n// Radial blur config\nconst int RADIAL_SAMPLES = 12;\nconst float RADIAL_STRENGTH = 0.0003;\n\nuniform sampler2D u_image0;\nuniform int u_int0; // Blur type (BLUR_GAUSSIAN, BLUR_BOX, BLUR_RADIAL)\nuniform float u_float0; // Blur radius/amount\nuniform int u_pass; // Pass index (0 = horizontal, 1 = vertical)\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nvoid main() {\n vec2 texelSize = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float0, 0.0);\n\n // Radial (angular) blur - single pass, doesn't use separable\n if (u_int0 == BLUR_RADIAL) {\n // Only execute on first pass\n if (u_pass > 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec2 center = vec2(0.5);\n vec2 dir = v_texCoord - center;\n float dist = length(dir);\n\n if (dist < 1e-4) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n vec4 sum = vec4(0.0);\n float totalWeight = 0.0;\n float angleStep = radius * RADIAL_STRENGTH;\n\n dir /= dist;\n\n float cosStep = cos(angleStep);\n float sinStep = sin(angleStep);\n\n float negAngle = -float(RADIAL_SAMPLES) * angleStep;\n vec2 rotDir = vec2(\n dir.x * cos(negAngle) - dir.y * sin(negAngle),\n dir.x * sin(negAngle) + dir.y * cos(negAngle)\n );\n\n for (int i = -RADIAL_SAMPLES; i <= RADIAL_SAMPLES; i++) {\n vec2 uv = center + rotDir * dist;\n float w = 1.0 - abs(float(i)) / float(RADIAL_SAMPLES);\n sum += texture(u_image0, uv) * w;\n totalWeight += w;\n\n rotDir = vec2(\n rotDir.x * cosStep - rotDir.y * sinStep,\n rotDir.x * sinStep + rotDir.y * cosStep\n );\n }\n\n fragColor0 = sum / max(totalWeight, 0.001);\n return;\n }\n\n // Separable Gaussian / Box blur\n int samples = int(ceil(radius));\n\n if (samples == 0) {\n fragColor0 = texture(u_image0, v_texCoord);\n return;\n }\n\n // Direction: pass 0 = horizontal, pass 1 = vertical\n vec2 dir = (u_pass == 0) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);\n\n vec4 color = vec4(0.0);\n float totalWeight = 0.0;\n float sigma = radius / 2.0;\n\n for (int i = -samples; i <= samples; i++) {\n vec2 offset = dir * float(i) * texelSize;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float weight;\n if (u_int0 == BLUR_GAUSSIAN) {\n weight = gaussian(float(i), sigma);\n } else {\n // BLUR_BOX\n weight = 1.0;\n }\n\n color += sample_color * weight;\n totalWeight += weight;\n }\n\n fragColor0 = color / totalWeight;\n}\n", "from_input" ] } diff --git a/blueprints/Sharpen.json b/blueprints/Sharpen.json index bb79f61fc..f332400fd 100644 --- a/blueprints/Sharpen.json +++ b/blueprints/Sharpen.json @@ -267,7 +267,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", + "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // strength [0.0 – 2.0] typical: 0.3–1.0\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n \n // Sample center and neighbors\n vec4 center = texture(u_image0, v_texCoord);\n vec4 top = texture(u_image0, v_texCoord + vec2( 0.0, -texel.y));\n vec4 bottom = texture(u_image0, v_texCoord + vec2( 0.0, texel.y));\n vec4 left = texture(u_image0, v_texCoord + vec2(-texel.x, 0.0));\n vec4 right = texture(u_image0, v_texCoord + vec2( texel.x, 0.0));\n \n // Edge enhancement (Laplacian)\n vec4 edges = center * 4.0 - top - bottom - left - right;\n \n // Add edges back scaled by strength\n vec4 sharpened = center + edges * u_float0;\n \n fragColor0 = vec4(clamp(sharpened.rgb, 0.0, 1.0), center.a);\n}", "from_input" ] } diff --git a/blueprints/Unsharp Mask.json b/blueprints/Unsharp Mask.json index b673eb703..137acaa43 100644 --- a/blueprints/Unsharp Mask.json +++ b/blueprints/Unsharp Mask.json @@ -383,7 +383,7 @@ "Node name for S&R": "GLSLShader" }, "widgets_values": [ - "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform vec2 u_resolution;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / u_resolution;\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", + "#version 300 es\nprecision highp float;\n\nuniform sampler2D u_image0;\nuniform float u_float0; // amount [0.0 - 3.0] typical: 0.5-1.5\nuniform float u_float1; // radius [0.5 - 10.0] blur radius in pixels\nuniform float u_float2; // threshold [0.0 - 0.1] min difference to sharpen\n\nin vec2 v_texCoord;\nlayout(location = 0) out vec4 fragColor0;\n\nfloat gaussian(float x, float sigma) {\n return exp(-(x * x) / (2.0 * sigma * sigma));\n}\n\nfloat getLuminance(vec3 color) {\n return dot(color, vec3(0.2126, 0.7152, 0.0722));\n}\n\nvoid main() {\n vec2 texel = 1.0 / vec2(textureSize(u_image0, 0));\n float radius = max(u_float1, 0.5);\n float amount = u_float0;\n float threshold = u_float2;\n\n vec4 original = texture(u_image0, v_texCoord);\n\n // Gaussian blur for the \"unsharp\" mask\n int samples = int(ceil(radius));\n float sigma = radius / 2.0;\n\n vec4 blurred = vec4(0.0);\n float totalWeight = 0.0;\n\n for (int x = -samples; x <= samples; x++) {\n for (int y = -samples; y <= samples; y++) {\n vec2 offset = vec2(float(x), float(y)) * texel;\n vec4 sample_color = texture(u_image0, v_texCoord + offset);\n\n float dist = length(vec2(float(x), float(y)));\n float weight = gaussian(dist, sigma);\n blurred += sample_color * weight;\n totalWeight += weight;\n }\n }\n blurred /= totalWeight;\n\n // Unsharp mask = original - blurred\n vec3 mask = original.rgb - blurred.rgb;\n\n // Luminance-based threshold with smooth falloff\n float lumaDelta = abs(getLuminance(original.rgb) - getLuminance(blurred.rgb));\n float thresholdScale = smoothstep(0.0, threshold, lumaDelta);\n mask *= thresholdScale;\n\n // Sharpen: original + mask * amount\n vec3 sharpened = original.rgb + mask * amount;\n\n fragColor0 = vec4(clamp(sharpened, 0.0, 1.0), original.a);\n}\n", "from_input" ] } diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index fa0a00748..dd5320c8f 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -4,9 +4,6 @@ import math import torch import torchaudio -import comfy.model_management -import comfy.model_patcher -import comfy.utils as utils from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( @@ -43,30 +40,6 @@ class AudioVAEComponentConfig: return cls(autoencoder=audio_config, vocoder=vocoder_config) - -class ModelDeviceManager: - """Manages device placement and GPU residency for the composed model.""" - - def __init__(self, module: torch.nn.Module): - load_device = comfy.model_management.get_torch_device() - offload_device = comfy.model_management.vae_offload_device() - self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) - - def ensure_model_loaded(self) -> None: - comfy.model_management.free_memory( - self.patcher.model_size(), - self.patcher.load_device, - ) - comfy.model_management.load_model_gpu(self.patcher) - - def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: - return tensor.to(self.patcher.load_device) - - @property - def load_device(self): - return self.patcher.load_device - - class AudioLatentNormalizer: """Applies per-channel statistics in patch space and restores original layout.""" @@ -132,23 +105,17 @@ class AudioPreprocessor: class AudioVAE(torch.nn.Module): """High-level Audio VAE wrapper exposing encode and decode entry points.""" - def __init__(self, state_dict: dict, metadata: dict): + def __init__(self, metadata: dict): super().__init__() component_config = AudioVAEComponentConfig.from_metadata(metadata) - vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) - vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) - self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) if "bwe" in component_config.vocoder: self.vocoder = VocoderWithBWE(config=component_config.vocoder) else: self.vocoder = Vocoder(config=component_config.vocoder) - self.autoencoder.load_state_dict(vae_sd, strict=False) - self.vocoder.load_state_dict(vocoder_sd, strict=False) - autoencoder_config = self.autoencoder.get_config() self.normalizer = AudioLatentNormalizer( AudioPatchifier( @@ -168,18 +135,12 @@ class AudioVAE(torch.nn.Module): n_fft=autoencoder_config["n_fft"], ) - self.device_manager = ModelDeviceManager(self) - - def encode(self, audio: dict) -> torch.Tensor: + def encode(self, audio, sample_rate=44100) -> torch.Tensor: """Encode a waveform dictionary into normalized latent tensors.""" - waveform = audio["waveform"] - waveform_sample_rate = audio["sample_rate"] + waveform = audio + waveform_sample_rate = sample_rate input_device = waveform.device - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - waveform = self.device_manager.move_to_load_device(waveform) expected_channels = self.autoencoder.encoder.in_channels if waveform.shape[1] != expected_channels: if waveform.shape[1] == 1: @@ -190,7 +151,7 @@ class AudioVAE(torch.nn.Module): ) mel_spec = self.preprocessor.waveform_to_mel( - waveform, waveform_sample_rate, device=self.device_manager.load_device + waveform, waveform_sample_rate, device=waveform.device ) latents = self.autoencoder.encode(mel_spec) @@ -204,17 +165,13 @@ class AudioVAE(torch.nn.Module): """Decode normalized latent tensors into an audio waveform.""" original_shape = latents.shape - # Ensure that Audio VAE is loaded on the correct device. - self.device_manager.ensure_model_loaded() - - latents = self.device_manager.move_to_load_device(latents) latents = self.normalizer.denormalize(latents) target_shape = self.target_shape_from_latents(original_shape) mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) waveform = self.run_vocoder(mel_spec) - return self.device_manager.move_to_load_device(waveform) + return waveform def target_shape_from_latents(self, latents_shape): batch, _, time, _ = latents_shape diff --git a/comfy/ldm/sam3/detector.py b/comfy/ldm/sam3/detector.py new file mode 100644 index 000000000..12d3a01ab --- /dev/null +++ b/comfy/ldm/sam3/detector.py @@ -0,0 +1,596 @@ +# SAM3 detector: transformer encoder-decoder, segmentation head, geometry encoder, scoring. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import roi_align + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.tracker import SAM3Tracker, SAM31Tracker +from comfy.ldm.sam3.sam import SAM3VisionBackbone # noqa: used in __init__ +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingSine + +TRACKER_CLASSES = {"SAM3": SAM3Tracker, "SAM31": SAM31Tracker} +from comfy.ops import cast_to_input + + +def box_cxcywh_to_xyxy(x): + cx, cy, w, h = x.unbind(-1) + return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1) + + +def gen_sineembed_for_position(pos_tensor, num_feats=256): + """Per-coordinate sinusoidal embedding: (..., N) -> (..., N * num_feats).""" + assert num_feats % 2 == 0 + hdim = num_feats // 2 + freqs = 10000.0 ** (2 * (torch.arange(hdim, dtype=torch.float32, device=pos_tensor.device) // 2) / hdim) + embeds = [] + for c in range(pos_tensor.shape[-1]): + raw = (pos_tensor[..., c].float() * 2 * math.pi).unsqueeze(-1) / freqs + embeds.append(torch.stack([raw[..., 0::2].sin(), raw[..., 1::2].cos()], dim=-1).flatten(-2)) + return torch.cat(embeds, dim=-1).to(pos_tensor.dtype) + + +class SplitMHA(nn.Module): + """Multi-head attention with separate Q/K/V projections (split from fused in_proj_weight).""" + def __init__(self, d_model, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def forward(self, q_input, k_input=None, v_input=None, mask=None): + q = self.q_proj(q_input) + if k_input is None: + k = self.k_proj(q_input) + v = self.v_proj(q_input) + else: + k = self.k_proj(k_input) + v = self.v_proj(v_input if v_input is not None else k_input) + if mask is not None and mask.ndim == 2: + mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast + dtype = q.dtype # manual_cast may produce mixed dtypes + out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False) + return self.out_proj(out) + + +class MLPWithNorm(nn.Module): + """MLP with residual connection and output LayerNorm.""" + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, residual=True, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([ + operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) + for i in range(num_layers) + ]) + self.out_norm = operations.LayerNorm(output_dim, device=device, dtype=dtype) + self.residual = residual and (input_dim == output_dim) + + def forward(self, x): + orig = x + for i, layer in enumerate(self.layers): + x = layer(x) + if i < len(self.layers) - 1: + x = F.relu(x) + if self.residual: + x = x + orig + return self.out_norm(x) + + +class EncoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, pos, text_memory=None, text_mask=None): + normed = self.norm1(x) + q_k = normed + pos + x = x + self.self_attn(q_k, q_k, normed) + if text_memory is not None: + normed = self.norm2(x) + x = x + self.cross_attn_image(normed, text_memory, text_memory, mask=text_mask) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class TransformerEncoder(nn.Module): + """Checkpoint: transformer.encoder.layers.N.*""" + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + EncoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + + def forward(self, x, pos, text_memory=None, text_mask=None): + for layer in self.layers: + x = layer(x, pos, text_memory, text_mask) + return x + + +class DecoderLayer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.self_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.ca_text = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.catext_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + + def forward(self, x, memory, x_pos, memory_pos, text_memory=None, text_mask=None, cross_attn_bias=None): + q_k = x + x_pos + x = self.norm2(x + self.self_attn(q_k, q_k, x)) + if text_memory is not None: + x = self.catext_norm(x + self.ca_text(x + x_pos, text_memory, text_memory, mask=text_mask)) + x = self.norm1(x + self.cross_attn(x + x_pos, memory + memory_pos, memory, mask=cross_attn_bias)) + x = self.norm3(x + self.linear2(F.relu(self.linear1(x)))) + return x + + +class TransformerDecoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, num_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.num_queries = num_queries + + self.layers = nn.ModuleList([ + DecoderLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.query_embed = operations.Embedding(num_queries, d_model, device=device, dtype=dtype) + self.reference_points = operations.Embedding(num_queries, 4, device=device, dtype=dtype) # Reference points: Embedding(num_queries, 4) — learned anchor boxes + self.ref_point_head = MLP(d_model * 2, d_model, d_model, 2, device=device, dtype=dtype, operations=operations) # ref_point_head input: 512 (4 coords * 128 sine features each) + self.bbox_embed = MLP(d_model, d_model, 4, 3, device=device, dtype=dtype, operations=operations) + + self.boxRPB_embed_x = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + self.boxRPB_embed_y = MLP(2, d_model, num_heads, 2, device=device, dtype=dtype, operations=operations) + + self.presence_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.presence_token_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + self.presence_token_out_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + + @staticmethod + def _inverse_sigmoid(x): + return torch.log(x / (1 - x + 1e-6) + 1e-6) + + def _compute_box_rpb(self, ref_points, H, W): + """Box rotary position bias: (B, Q, 4) cxcywh -> (B, n_heads, Q+1, H*W) bias.""" + boxes_xyxy = box_cxcywh_to_xyxy(ref_points) + B, Q, _ = boxes_xyxy.shape + coords_h = torch.arange(H, device=ref_points.device, dtype=torch.float32) / H + coords_w = torch.arange(W, device=ref_points.device, dtype=torch.float32) / W + deltas_x = coords_w.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 0:3:2] + deltas_y = coords_h.view(1, 1, -1, 1) - boxes_xyxy[:, :, None, 1:4:2] + + log2_8 = float(math.log2(8)) + def log_scale(d): + return torch.sign(d * 8) * torch.log2(torch.abs(d * 8) + 1.0) / log2_8 + + rpb_x = self.boxRPB_embed_x(log_scale(deltas_x).to(ref_points.dtype)) + rpb_y = self.boxRPB_embed_y(log_scale(deltas_y).to(ref_points.dtype)) + + bias = (rpb_y.unsqueeze(3) + rpb_x.unsqueeze(2)).flatten(2, 3).permute(0, 3, 1, 2) + pres_bias = torch.zeros(B, bias.shape[1], 1, bias.shape[3], device=bias.device, dtype=bias.dtype) + return torch.cat([pres_bias, bias], dim=2) + + def forward(self, memory, memory_pos, text_memory=None, text_mask=None, H=72, W=72): + B = memory.shape[0] + tgt = cast_to_input(self.query_embed.weight, memory).unsqueeze(0).expand(B, -1, -1) + presence_out = cast_to_input(self.presence_token.weight, memory)[None].expand(B, -1, -1) + ref_points = cast_to_input(self.reference_points.weight, memory).unsqueeze(0).expand(B, -1, -1).sigmoid() + + for layer_idx, layer in enumerate(self.layers): + query_pos = self.ref_point_head(gen_sineembed_for_position(ref_points, self.d_model)) + tgt_with_pres = torch.cat([presence_out, tgt], dim=1) + pos_with_pres = torch.cat([torch.zeros_like(presence_out), query_pos], dim=1) + tgt_with_pres = layer(tgt_with_pres, memory, pos_with_pres, memory_pos, + text_memory, text_mask, self._compute_box_rpb(ref_points, H, W)) + presence_out, tgt = tgt_with_pres[:, :1], tgt_with_pres[:, 1:] + if layer_idx < len(self.layers) - 1: + ref_inv = self._inverse_sigmoid(ref_points) + ref_points = (ref_inv + self.bbox_embed(self.norm(tgt))).sigmoid().detach() + + query_out = self.norm(tgt) + ref_inv = self._inverse_sigmoid(ref_points) + boxes = (ref_inv + self.bbox_embed(query_out)).sigmoid() + presence = self.presence_token_head(self.presence_token_out_norm(presence_out)).squeeze(-1) + return {"decoder_output": query_out, "pred_boxes": boxes, "presence": presence} + + +class Transformer(nn.Module): + def __init__(self, d_model=256, num_heads=8, dim_ff=2048, enc_layers=6, dec_layers=6, + num_queries=200, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = TransformerEncoder(d_model, num_heads, dim_ff, enc_layers, device=device, dtype=dtype, operations=operations) + self.decoder = TransformerDecoder(d_model, num_heads, dim_ff, dec_layers, num_queries, device=device, dtype=dtype, operations=operations) + + +class GeometryEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=8, num_layers=3, roi_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.roi_size = roi_size + self.pos_enc = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.points_direct_project = operations.Linear(2, d_model, device=device, dtype=dtype) + self.points_pool_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.points_pos_enc_project = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.boxes_direct_project = operations.Linear(4, d_model, device=device, dtype=dtype) + self.boxes_pool_project = operations.Conv2d(d_model, d_model, kernel_size=roi_size, device=device, dtype=dtype) + self.boxes_pos_enc_project = operations.Linear(d_model + 2, d_model, device=device, dtype=dtype) + self.label_embed = operations.Embedding(2, d_model, device=device, dtype=dtype) + self.cls_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.img_pre_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.encode = nn.ModuleList([ + EncoderLayer(d_model, num_heads, 2048, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.encode_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.final_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + + def _encode_points(self, coords, labels, img_feat_2d): + """Encode point prompts: direct + pool + pos_enc + label. coords: [B, N, 2] normalized.""" + B, N, _ = coords.shape + embed = self.points_direct_project(coords) + # Pool features from backbone at point locations via grid_sample + grid = (coords * 2 - 1).unsqueeze(2) # [B, N, 1, 2] in [-1, 1] + sampled = F.grid_sample(img_feat_2d, grid, align_corners=False) # [B, C, N, 1] + embed = embed + self.points_pool_project(sampled.squeeze(-1).permute(0, 2, 1)) # [B, N, C] + # Positional encoding of coordinates + x, y = coords[:, :, 0], coords[:, :, 1] # [B, N] + pos_x, pos_y = self.pos_enc._encode_xy(x.flatten(), y.flatten()) + enc = torch.cat([pos_x, pos_y], dim=-1).view(B, N, -1) + embed = embed + self.points_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def _encode_boxes(self, boxes, labels, img_feat_2d): + """Encode box prompts: direct + pool + pos_enc + label. boxes: [B, N, 4] normalized cxcywh.""" + B, N, _ = boxes.shape + embed = self.boxes_direct_project(boxes) + # ROI align from backbone at box regions + H, W = img_feat_2d.shape[-2:] + boxes_xyxy = box_cxcywh_to_xyxy(boxes) + scale = torch.tensor([W, H, W, H], dtype=boxes_xyxy.dtype, device=boxes_xyxy.device) + boxes_scaled = boxes_xyxy * scale + sampled = roi_align(img_feat_2d, boxes_scaled.view(-1, 4).split(N), self.roi_size) + proj = self.boxes_pool_project(sampled).view(B, N, -1) # Conv2d(roi_size) -> [B*N, C, 1, 1] -> [B, N, C] + embed = embed + proj + # Positional encoding of box center + size + cx, cy, w, h = boxes[:, :, 0], boxes[:, :, 1], boxes[:, :, 2], boxes[:, :, 3] + enc = self.pos_enc.encode_boxes(cx.flatten(), cy.flatten(), w.flatten(), h.flatten()) + enc = enc.view(B, N, -1) + embed = embed + self.boxes_pos_enc_project(cast_to_input(enc, embed)) + embed = embed + cast_to_input(self.label_embed(labels.long()), embed) + return embed + + def forward(self, points=None, boxes=None, image_features=None): + """Encode geometry prompts. image_features: [B, HW, C] flattened backbone features.""" + # Prepare 2D image features for pooling + img_feat_2d = None + if image_features is not None: + B = image_features.shape[0] + HW, C = image_features.shape[1], image_features.shape[2] + hw = int(math.sqrt(HW)) + img_normed = self.img_pre_norm(image_features) + img_feat_2d = img_normed.permute(0, 2, 1).view(B, C, hw, hw) + + embeddings = [] + if points is not None: + coords, labels = points + embeddings.append(self._encode_points(coords, labels, img_feat_2d)) + if boxes is not None: + B = boxes.shape[0] + box_labels = torch.ones(B, boxes.shape[1], dtype=torch.long, device=boxes.device) + embeddings.append(self._encode_boxes(boxes, box_labels, img_feat_2d)) + if not embeddings: + return None + geo = torch.cat(embeddings, dim=1) + geo = self.norm(geo) + if image_features is not None: + for layer in self.encode: + geo = layer(geo, torch.zeros_like(geo), image_features) + geo = self.encode_norm(geo) + return self.final_proj(geo) + + +class PixelDecoder(nn.Module): + """Top-down FPN pixel decoder with GroupNorm + ReLU + nearest interpolation.""" + def __init__(self, d_model=256, num_stages=3, device=None, dtype=None, operations=None): + super().__init__() + self.conv_layers = nn.ModuleList([operations.Conv2d(d_model, d_model, kernel_size=3, padding=1, device=device, dtype=dtype) for _ in range(num_stages)]) + self.norms = nn.ModuleList([operations.GroupNorm(8, d_model, device=device, dtype=dtype) for _ in range(num_stages)]) + + def forward(self, backbone_features): + prev = backbone_features[-1] + for i, feat in enumerate(backbone_features[:-1][::-1]): + prev = F.relu(self.norms[i](self.conv_layers[i](feat + F.interpolate(prev, size=feat.shape[-2:], mode="nearest")))) + return prev + + +class MaskPredictor(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.mask_embed = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, query_embeddings, pixel_features): + mask_embed = self.mask_embed(query_embeddings) + return torch.einsum("bqc,bchw->bqhw", mask_embed, pixel_features) + + +class SegmentationHead(nn.Module): + def __init__(self, d_model=256, num_heads=8, device=None, dtype=None, operations=None): + super().__init__() + self.d_model = d_model + self.pixel_decoder = PixelDecoder(d_model, 3, device=device, dtype=dtype, operations=operations) + self.mask_predictor = MaskPredictor(d_model, device=device, dtype=dtype, operations=operations) + self.cross_attend_prompt = SplitMHA(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.instance_seg_head = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.semantic_seg_head = operations.Conv2d(d_model, 1, kernel_size=1, device=device, dtype=dtype) + + def forward(self, query_embeddings, backbone_features, encoder_hidden_states=None, prompt=None, prompt_mask=None): + if encoder_hidden_states is not None and prompt is not None: + enc_normed = self.cross_attn_norm(encoder_hidden_states) + enc_cross = self.cross_attend_prompt(enc_normed, prompt, prompt, mask=prompt_mask) + encoder_hidden_states = enc_cross + encoder_hidden_states + + if encoder_hidden_states is not None: + B, H, W = encoder_hidden_states.shape[0], backbone_features[-1].shape[-2], backbone_features[-1].shape[-1] + encoder_visual = encoder_hidden_states[:, :H * W].permute(0, 2, 1).view(B, self.d_model, H, W) + backbone_features = list(backbone_features) + backbone_features[-1] = encoder_visual + + pixel_features = self.pixel_decoder(backbone_features) + instance_features = self.instance_seg_head(pixel_features) + masks = self.mask_predictor(query_embeddings, instance_features) + return masks + + +class DotProductScoring(nn.Module): + def __init__(self, d_model=256, device=None, dtype=None, operations=None): + super().__init__() + self.hs_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.prompt_mlp = MLPWithNorm(d_model, 2048, d_model, 2, device=device, dtype=dtype, operations=operations) + self.scale = 1.0 / (d_model ** 0.5) + + def forward(self, query_embeddings, prompt_embeddings, prompt_mask=None): + prompt = self.prompt_mlp(prompt_embeddings) + if prompt_mask is not None: + weight = prompt_mask.unsqueeze(-1).to(dtype=prompt.dtype) + pooled = (prompt * weight).sum(dim=1) / weight.sum(dim=1).clamp(min=1) + else: + pooled = prompt.mean(dim=1) + hs = self.hs_proj(query_embeddings) + pp = self.prompt_proj(pooled).unsqueeze(-1).to(hs.dtype) + scores = torch.matmul(hs, pp) + return (scores * self.scale).clamp(-12.0, 12.0).squeeze(-1) + + +class SAM3Detector(nn.Module): + def __init__(self, d_model=256, embed_dim=1024, num_queries=200, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + image_model = kwargs.pop("image_model", "SAM3") + for k in ("num_heads", "num_head_channels"): + kwargs.pop(k, None) + multiplex = image_model == "SAM31" + # SAM3: 4 FPN levels, drop last (scalp=1); SAM3.1: 3 levels, use all (scalp=0) + self.scalp = 0 if multiplex else 1 + self.backbone = nn.ModuleDict({ + "vision_backbone": SAM3VisionBackbone(embed_dim=embed_dim, d_model=d_model, multiplex=multiplex, device=device, dtype=dtype, operations=operations, **kwargs), + "language_backbone": nn.ModuleDict({"resizer": operations.Linear(embed_dim, d_model, device=device, dtype=dtype)}), + }) + self.transformer = Transformer(d_model=d_model, num_queries=num_queries, device=device, dtype=dtype, operations=operations) + self.segmentation_head = SegmentationHead(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.geometry_encoder = GeometryEncoder(d_model=d_model, device=device, dtype=dtype, operations=operations) + self.dot_prod_scoring = DotProductScoring(d_model=d_model, device=device, dtype=dtype, operations=operations) + + def _get_backbone_features(self, images): + """Run backbone and return (detector_features, detector_positions, tracker_features, tracker_positions).""" + bb = self.backbone["vision_backbone"] + if bb.multiplex: + all_f, all_p, tf, tp = bb(images, tracker_mode="propagation") + else: + all_f, all_p, tf, tp = bb(images, need_tracker=True) + return all_f, all_p, tf, tp + + @staticmethod + def _run_geo_layer(layer, x, memory, memory_pos): + x = x + layer.self_attn(layer.norm1(x)) + x = x + layer.cross_attn_image(layer.norm2(x), memory + memory_pos, memory) + x = x + layer.linear2(F.relu(layer.linear1(layer.norm3(x)))) + return x + + def _detect(self, features, positions, text_embeddings=None, text_mask=None, + points=None, boxes=None): + """Shared detection: geometry encoding, transformer, scoring, segmentation.""" + B = features[0].shape[0] + # Scalp for encoder (use top-level feature), but keep all levels for segmentation head + seg_features = features + if self.scalp > 0: + features = features[:-self.scalp] + positions = positions[:-self.scalp] + enc_feat, enc_pos = features[-1], positions[-1] + _, _, H, W = enc_feat.shape + img_flat = enc_feat.flatten(2).permute(0, 2, 1) + pos_flat = enc_pos.flatten(2).permute(0, 2, 1) + + has_prompts = text_embeddings is not None or points is not None or boxes is not None + if has_prompts: + geo_enc = self.geometry_encoder + geo_prompts = geo_enc(points=points, boxes=boxes, image_features=img_flat) + geo_cls = geo_enc.norm(geo_enc.final_proj(cast_to_input(geo_enc.cls_embed.weight, img_flat).view(1, 1, -1).expand(B, -1, -1))) + for layer in geo_enc.encode: + geo_cls = self._run_geo_layer(layer, geo_cls, img_flat, pos_flat) + geo_cls = geo_enc.encode_norm(geo_cls) + if text_embeddings is not None and text_embeddings.shape[0] != B: + text_embeddings = text_embeddings.expand(B, -1, -1) + if text_mask is not None and text_mask.shape[0] != B: + text_mask = text_mask.expand(B, -1) + parts = [t for t in [text_embeddings, geo_prompts, geo_cls] if t is not None] + text_embeddings = torch.cat(parts, dim=1) + n_new = text_embeddings.shape[1] - (text_mask.shape[1] if text_mask is not None else 0) + if text_mask is not None: + text_mask = torch.cat([text_mask, torch.ones(B, n_new, dtype=torch.bool, device=text_mask.device)], dim=1) + else: + text_mask = torch.ones(B, text_embeddings.shape[1], dtype=torch.bool, device=text_embeddings.device) + + memory = self.transformer.encoder(img_flat, pos_flat, text_embeddings, text_mask) + dec_out = self.transformer.decoder(memory, pos_flat, text_embeddings, text_mask, H, W) + query_out, pred_boxes = dec_out["decoder_output"], dec_out["pred_boxes"] + + if text_embeddings is not None: + scores = self.dot_prod_scoring(query_out, text_embeddings, text_mask) + else: + scores = torch.zeros(B, query_out.shape[1], device=query_out.device) + + masks = self.segmentation_head(query_out, seg_features, encoder_hidden_states=memory, prompt=text_embeddings, prompt_mask=text_mask) + return box_cxcywh_to_xyxy(pred_boxes), scores, masks, dec_out + + def forward(self, images, text_embeddings=None, text_mask=None, points=None, boxes=None, threshold=0.3, orig_size=None): + features, positions, _, _ = self._get_backbone_features(images) + + if text_embeddings is not None: + text_embeddings = self.backbone["language_backbone"]["resizer"](text_embeddings) + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, dec_out = self._detect( + features, positions, text_embeddings, text_mask, points, boxes) + + if orig_size is not None: + oh, ow = orig_size + boxes_xyxy = boxes_xyxy * torch.tensor([ow, oh, ow, oh], device=boxes_xyxy.device, dtype=boxes_xyxy.dtype) + masks = F.interpolate(masks, size=orig_size, mode="bilinear", align_corners=False) + + return { + "boxes": boxes_xyxy, + "scores": scores, + "masks": masks, + "presence": dec_out.get("presence"), + } + + def forward_from_trunk(self, trunk_out, text_embeddings, text_mask): + """Run detection using a pre-computed ViTDet trunk output. + + text_embeddings must already be resized through language_backbone.resizer. + Returns dict with boxes (normalized xyxy), scores, masks at detector resolution. + """ + bb = self.backbone["vision_backbone"] + features = [conv(trunk_out) for conv in bb.convs] + positions = [cast_to_input(bb.position_encoding(f), f) for f in features] + + if text_mask is not None: + text_mask = text_mask.bool() + + boxes_xyxy, scores, masks, _ = self._detect(features, positions, text_embeddings, text_mask) + return {"boxes": boxes_xyxy, "scores": scores, "masks": masks} + + +class SAM3Model(nn.Module): + def __init__(self, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + image_model = kwargs.get("image_model", "SAM3") + tracker_cls = TRACKER_CLASSES[image_model] + self.detector = SAM3Detector(device=device, dtype=dtype, operations=operations, **kwargs) + self.tracker = tracker_cls(device=device, dtype=dtype, operations=operations, **kwargs) + + def forward(self, images, **kwargs): + return self.detector(images, **kwargs) + + def forward_segment(self, images, point_inputs=None, box_inputs=None, mask_inputs=None): + """Interactive segmentation using SAM decoder with point/box/mask prompts. + + Args: + images: [B, 3, 1008, 1008] preprocessed images + point_inputs: {"point_coords": [B, N, 2], "point_labels": [B, N]} in 1008x1008 pixel space + box_inputs: [B, 2, 2] box corners (top-left, bottom-right) in 1008x1008 pixel space + mask_inputs: [B, 1, H, W] coarse mask logits to refine + Returns: + [B, 1, image_size, image_size] high-res mask logits + """ + bb = self.detector.backbone["vision_backbone"] + if bb.multiplex: + _, _, tracker_features, tracker_positions = bb(images, tracker_mode="interactive") + else: + _, _, tracker_features, tracker_positions = bb(images, need_tracker=True) + if self.detector.scalp > 0: + tracker_features = tracker_features[:-self.detector.scalp] + tracker_positions = tracker_positions[:-self.detector.scalp] + + high_res = list(tracker_features[:-1]) + backbone_feat = tracker_features[-1] + B, C, H, W = backbone_feat.shape + # Add no-memory embedding (init frame path) + no_mem = getattr(self.tracker, 'interactivity_no_mem_embed', None) + if no_mem is None: + no_mem = getattr(self.tracker, 'no_mem_embed', None) + if no_mem is not None: + feat_flat = backbone_feat.flatten(2).permute(0, 2, 1) + feat_flat = feat_flat + cast_to_input(no_mem, feat_flat) + backbone_feat = feat_flat.view(B, H, W, C).permute(0, 3, 1, 2) + + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + _, high_res_masks, _, _ = self.tracker._forward_sam_heads( + backbone_features=backbone_feat, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + box_inputs=box_inputs, + high_res_features=high_res, + multimask_output=(0 < num_pts <= 1), + ) + return high_res_masks + + def forward_video(self, images, initial_masks, pbar=None, text_prompts=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1): + """Track video with optional per-frame text-prompted detection.""" + bb = self.detector.backbone["vision_backbone"] + + def backbone_fn(frame, frame_idx=None): + trunk_out = bb.trunk(frame) + if bb.multiplex: + _, _, tf, tp = bb(frame, tracker_mode="propagation", cached_trunk=trunk_out, tracker_only=True) + else: + _, _, tf, tp = bb(frame, need_tracker=True, cached_trunk=trunk_out, tracker_only=True) + return tf, tp, trunk_out + + detect_fn = None + if text_prompts: + resizer = self.detector.backbone["language_backbone"]["resizer"] + resized = [(resizer(emb), m.bool() if m is not None else None) for emb, m in text_prompts] + def detect_fn(trunk_out): + all_scores, all_masks = [], [] + for emb, mask in resized: + det = self.detector.forward_from_trunk(trunk_out, emb, mask) + all_scores.append(det["scores"]) + all_masks.append(det["masks"]) + return {"scores": torch.cat(all_scores, dim=1), "masks": torch.cat(all_masks, dim=1)} + + if hasattr(self.tracker, 'track_video_with_detection'): + return self.tracker.track_video_with_detection( + backbone_fn, images, initial_masks, detect_fn, + new_det_thresh=new_det_thresh, max_objects=max_objects, + detect_interval=detect_interval, backbone_obj=bb, pbar=pbar) + # SAM3 (non-multiplex) — no detection support, requires initial masks + if initial_masks is None: + raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking") + return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb) diff --git a/comfy/ldm/sam3/sam.py b/comfy/ldm/sam3/sam.py new file mode 100644 index 000000000..75cb457cf --- /dev/null +++ b/comfy/ldm/sam3/sam.py @@ -0,0 +1,425 @@ +# SAM3 shared components: primitives, ViTDet backbone, FPN neck, position encodings. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.flux.layers import EmbedND +from comfy.ops import cast_to_input + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, sigmoid_output=False, device=None, dtype=None, operations=None): + super().__init__() + dims = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim] + self.layers = nn.ModuleList([operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers)]) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < len(self.layers) - 1 else layer(x) + return torch.sigmoid(x) if self.sigmoid_output else x + + +class SAMAttention(nn.Module): + def __init__(self, embedding_dim, num_heads, downsample_rate=1, kv_in_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + internal_dim = embedding_dim // downsample_rate + kv_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.q_proj = operations.Linear(embedding_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embedding_dim, device=device, dtype=dtype) + + def forward(self, q, k, v): + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + +class TwoWayAttentionBlock(nn.Module): + def __init__(self, embedding_dim, num_heads, mlp_dim=2048, attention_downsample_rate=2, skip_first_layer_pe=False, device=None, dtype=None, operations=None): + super().__init__() + self.skip_first_layer_pe = skip_first_layer_pe + self.self_attn = SAMAttention(embedding_dim, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.cross_attn_image_to_token = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.mlp = nn.Sequential(operations.Linear(embedding_dim, mlp_dim, device=device, dtype=dtype), nn.ReLU(), operations.Linear(mlp_dim, embedding_dim, device=device, dtype=dtype)) + self.norm1 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + self.norm4 = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, queries, keys, query_pe, key_pe): + if self.skip_first_layer_pe: + queries = self.norm1(self.self_attn(queries, queries, queries)) + else: + q = queries + query_pe + queries = self.norm1(queries + self.self_attn(q, q, queries)) + q, k = queries + query_pe, keys + key_pe + queries = self.norm2(queries + self.cross_attn_token_to_image(q, k, keys)) + queries = self.norm3(queries + self.mlp(queries)) + q, k = queries + query_pe, keys + key_pe + keys = self.norm4(keys + self.cross_attn_image_to_token(k, q, queries)) + return queries, keys + + +class TwoWayTransformer(nn.Module): + def __init__(self, depth=2, embedding_dim=256, num_heads=8, mlp_dim=2048, attention_downsample_rate=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + TwoWayAttentionBlock(embedding_dim, num_heads, mlp_dim, attention_downsample_rate, + skip_first_layer_pe=(i == 0), device=device, dtype=dtype, operations=operations) + for i in range(depth) + ]) + self.final_attn_token_to_image = SAMAttention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate, device=device, dtype=dtype, operations=operations) + self.norm_final = operations.LayerNorm(embedding_dim, device=device, dtype=dtype) + + def forward(self, image_embedding, image_pe, point_embedding): + queries, keys = point_embedding, image_embedding + for layer in self.layers: + queries, keys = layer(queries, keys, point_embedding, image_pe) + q, k = queries + point_embedding, keys + image_pe + queries = self.norm_final(queries + self.final_attn_token_to_image(q, k, keys)) + return queries, keys + + +class PositionEmbeddingRandom(nn.Module): + """Fourier feature positional encoding with random gaussian projection.""" + def __init__(self, num_pos_feats=64, scale=None): + super().__init__() + self.register_buffer("positional_encoding_gaussian_matrix", (scale or 1.0) * torch.randn(2, num_pos_feats)) + + def _encode(self, normalized_coords): + """Map normalized [0,1] coordinates to fourier features via random projection. Computes in fp32.""" + orig_dtype = normalized_coords.dtype + proj_matrix = self.positional_encoding_gaussian_matrix.to(device=normalized_coords.device, dtype=torch.float32) + projected = 2 * math.pi * (2 * normalized_coords.float() - 1) @ proj_matrix + return torch.cat([projected.sin(), projected.cos()], dim=-1).to(orig_dtype) + + def forward(self, size, device=None): + h, w = size + dev = device if device is not None else self.positional_encoding_gaussian_matrix.device + ones = torch.ones((h, w), device=dev, dtype=torch.float32) + norm_xy = torch.stack([(ones.cumsum(1) - 0.5) / w, (ones.cumsum(0) - 0.5) / h], dim=-1) + return self._encode(norm_xy).permute(2, 0, 1).unsqueeze(0) + + def forward_with_coords(self, pixel_coords, image_size): + norm = pixel_coords.clone() + norm[:, :, 0] /= image_size[1] + norm[:, :, 1] /= image_size[0] + return self._encode(norm) + + +# ViTDet backbone + FPN neck + +def window_partition(x: torch.Tensor, window_size: int): + B, H, W, C = x.shape + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw, hw): + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def rope_2d(end_x: int, end_y: int, dim: int, theta: float = 10000.0, scale_pos: float = 1.0): + """Generate 2D axial RoPE using flux EmbedND. Returns [1, 1, HW, dim//2, 2, 2].""" + t = torch.arange(end_x * end_y, dtype=torch.float32) + ids = torch.stack([(t % end_x) * scale_pos, + torch.div(t, end_x, rounding_mode="floor") * scale_pos], dim=-1) + return EmbedND(dim=dim, theta=theta, axes_dim=[dim // 2, dim // 2])(ids.unsqueeze(0)) + + +class _ViTMLP(nn.Module): + def __init__(self, dim, mlp_ratio=4.0, device=None, dtype=None, operations=None): + super().__init__() + hidden = int(dim * mlp_ratio) + self.fc1 = operations.Linear(dim, hidden, device=device, dtype=dtype) + self.act = nn.GELU() + self.fc2 = operations.Linear(hidden, dim, device=device, dtype=dtype) + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class Attention(nn.Module): + """ViTDet multi-head attention with fused QKV projection.""" + + def __init__(self, dim, num_heads=8, qkv_bias=True, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.use_rope = use_rope + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype) + self.proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + def forward(self, x, freqs_cis=None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) + q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0) + if self.use_rope and freqs_cis is not None: + q, k = apply_rope(q, k, freqs_cis) + return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False)) + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, window_size=0, use_rope=False, device=None, dtype=None, operations=None): + super().__init__() + self.window_size = window_size + self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.attn = Attention(dim, num_heads, qkv_bias, use_rope, device=device, dtype=dtype, operations=operations) + self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype) + self.mlp = _ViTMLP(dim, mlp_ratio, device=device, dtype=dtype, operations=operations) + + def forward(self, x, freqs_cis=None): + shortcut = x + x = self.norm1(x) + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + x = x.view(x.shape[0], self.window_size * self.window_size, -1) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(-1, self.window_size, self.window_size, x.shape[-1]) + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + else: + B, H, W, C = x.shape + x = x.view(B, H * W, C) + x = self.attn(x, freqs_cis=freqs_cis) + x = x.view(B, H, W, C) + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=14, in_chans=3, embed_dim=1024, device=None, dtype=None, operations=None): + super().__init__() + self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False, device=device, dtype=dtype) + + def forward(self, x): + return self.proj(x) + + +class ViTDet(nn.Module): + def __init__(self, img_size=1008, patch_size=14, embed_dim=1024, depth=32, num_heads=16, mlp_ratio=4.625, qkv_bias=True, window_size=24, + global_att_blocks=(7, 15, 23, 31), use_rope=True, pretrain_img_size=336, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.embed_dim = embed_dim + self.num_heads = num_heads + self.global_att_blocks = set(global_att_blocks) + + self.patch_embed = PatchEmbed(patch_size, 3, embed_dim, device=device, dtype=dtype, operations=operations) + + num_patches = (pretrain_img_size // patch_size) ** 2 + 1 # +1 for cls token + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim, device=device, dtype=dtype)) + + self.ln_pre = operations.LayerNorm(embed_dim, device=device, dtype=dtype) + + grid_size = img_size // patch_size + pretrain_grid = pretrain_img_size // patch_size + + self.blocks = nn.ModuleList() + for i in range(depth): + is_global = i in self.global_att_blocks + self.blocks.append(Block( + embed_dim, num_heads, mlp_ratio, qkv_bias, + window_size=0 if is_global else window_size, + use_rope=use_rope, + device=device, dtype=dtype, operations=operations, + )) + + if use_rope: + rope_scale = pretrain_grid / grid_size + self.register_buffer("freqs_cis", rope_2d(grid_size, grid_size, embed_dim // num_heads, scale_pos=rope_scale), persistent=False) + self.register_buffer("freqs_cis_window", rope_2d(window_size, window_size, embed_dim // num_heads), persistent=False) + else: + self.freqs_cis = None + self.freqs_cis_window = None + + def _get_pos_embed(self, num_tokens): + pos = self.pos_embed + if pos.shape[1] == num_tokens: + return pos + cls_pos = pos[:, :1] + spatial_pos = pos[:, 1:] + old_size = int(math.sqrt(spatial_pos.shape[1])) + new_size = int(math.sqrt(num_tokens - 1)) if num_tokens > 1 else old_size + spatial_2d = spatial_pos.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2) + tiles_h = new_size // old_size + 1 + tiles_w = new_size // old_size + 1 + tiled = spatial_2d.tile([1, 1, tiles_h, tiles_w])[:, :, :new_size, :new_size] + tiled = tiled.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1) + return torch.cat([cls_pos, tiled], dim=1) + + def forward(self, x): + x = self.patch_embed(x) + B, C, Hp, Wp = x.shape + x = x.permute(0, 2, 3, 1).reshape(B, Hp * Wp, C) + + pos = cast_to_input(self._get_pos_embed(Hp * Wp + 1), x) + x = x + pos[:, 1:Hp * Wp + 1] + + x = x.view(B, Hp, Wp, C) + x = self.ln_pre(x) + + freqs_cis_global = self.freqs_cis + freqs_cis_win = self.freqs_cis_window + if freqs_cis_global is not None: + freqs_cis_global = cast_to_input(freqs_cis_global, x) + if freqs_cis_win is not None: + freqs_cis_win = cast_to_input(freqs_cis_win, x) + + for block in self.blocks: + fc = freqs_cis_win if block.window_size > 0 else freqs_cis_global + x = block(x, freqs_cis=fc) + + return x.permute(0, 3, 1, 2) + + +class FPNScaleConv(nn.Module): + def __init__(self, in_dim, out_dim, scale, device=None, dtype=None, operations=None): + super().__init__() + if scale == 4.0: + self.dconv_2x2_0 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + self.dconv_2x2_1 = operations.ConvTranspose2d(in_dim // 2, in_dim // 4, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 4 + elif scale == 2.0: + self.dconv_2x2 = operations.ConvTranspose2d(in_dim, in_dim // 2, kernel_size=2, stride=2, device=device, dtype=dtype) + proj_in = in_dim // 2 + elif scale == 1.0: + proj_in = in_dim + elif scale == 0.5: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + proj_in = in_dim + self.scale = scale + self.conv_1x1 = operations.Conv2d(proj_in, out_dim, kernel_size=1, device=device, dtype=dtype) + self.conv_3x3 = operations.Conv2d(out_dim, out_dim, kernel_size=3, padding=1, device=device, dtype=dtype) + + def forward(self, x): + if self.scale == 4.0: + x = F.gelu(self.dconv_2x2_0(x)) + x = self.dconv_2x2_1(x) + elif self.scale == 2.0: + x = self.dconv_2x2(x) + elif self.scale == 0.5: + x = self.pool(x) + x = self.conv_1x1(x) + x = self.conv_3x3(x) + return x + + +class PositionEmbeddingSine(nn.Module): + """2D sinusoidal position encoding (DETR-style) with result caching.""" + def __init__(self, num_pos_feats=256, temperature=10000.0, normalize=True, scale=None): + super().__init__() + assert num_pos_feats % 2 == 0 + self.half_dim = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + self.scale = scale if scale is not None else 2 * math.pi + self._cache = {} + + def _sincos(self, vals): + """Encode 1D values to interleaved sin/cos features.""" + freqs = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=vals.device) // 2) / self.half_dim) + raw = vals[..., None] * self.scale / freqs + return torch.stack((raw[..., 0::2].sin(), raw[..., 1::2].cos()), dim=-1).flatten(-2) + + def _encode_xy(self, x, y): + """Encode normalized x, y coordinates to sinusoidal features. Returns (pos_x, pos_y) each [N, half_dim].""" + dim_t = self.temperature ** (2 * (torch.arange(self.half_dim, dtype=torch.float32, device=x.device) // 2) / self.half_dim) + pos_x = x[:, None] * self.scale / dim_t + pos_y = y[:, None] * self.scale / dim_t + pos_x = torch.stack((pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2).flatten(1) + pos_y = torch.stack((pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2).flatten(1) + return pos_x, pos_y + + def encode_boxes(self, cx, cy, w, h): + """Encode box center + size to [N, d_model+2] features.""" + pos_x, pos_y = self._encode_xy(cx, cy) + return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + + def forward(self, x): + B, C, H, W = x.shape + key = (H, W, x.device) + if key not in self._cache: + gy = torch.arange(H, dtype=torch.float32, device=x.device) + gx = torch.arange(W, dtype=torch.float32, device=x.device) + if self.normalize: + gy, gx = gy / (H - 1 + 1e-6), gx / (W - 1 + 1e-6) + yy, xx = torch.meshgrid(gy, gx, indexing="ij") + self._cache[key] = torch.cat((self._sincos(yy), self._sincos(xx)), dim=-1).permute(2, 0, 1).unsqueeze(0) + return self._cache[key].expand(B, -1, -1, -1) + + +class SAM3VisionBackbone(nn.Module): + def __init__(self, embed_dim=1024, d_model=256, multiplex=False, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.trunk = ViTDet(embed_dim=embed_dim, device=device, dtype=dtype, operations=operations, **kwargs) + self.position_encoding = PositionEmbeddingSine(num_pos_feats=d_model, normalize=True) + self.multiplex = multiplex + + fpn_args = dict(device=device, dtype=dtype, operations=operations) + if multiplex: + scales = [4.0, 2.0, 1.0] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.propagation_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.interactive_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + else: + scales = [4.0, 2.0, 1.0, 0.5] + self.convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + self.sam2_convs = nn.ModuleList([FPNScaleConv(embed_dim, d_model, s, **fpn_args) for s in scales]) + + def forward(self, images, need_tracker=False, tracker_mode=None, cached_trunk=None, tracker_only=False): + backbone_out = cached_trunk if cached_trunk is not None else self.trunk(images) + + if tracker_only: + # Skip detector FPN when only tracker features are needed (video tracking) + if self.multiplex: + tracker_convs = self.propagation_convs if tracker_mode == "propagation" else self.interactive_convs + else: + tracker_convs = self.sam2_convs + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return None, None, tracker_features, tracker_positions + + features = [conv(backbone_out) for conv in self.convs] + positions = [cast_to_input(self.position_encoding(f), f) for f in features] + + if self.multiplex: + if tracker_mode == "propagation": + tracker_convs = self.propagation_convs + elif tracker_mode == "interactive": + tracker_convs = self.interactive_convs + else: + return features, positions, None, None + elif need_tracker: + tracker_convs = self.sam2_convs + else: + return features, positions, None, None + + tracker_features = [conv(backbone_out) for conv in tracker_convs] + tracker_positions = [cast_to_input(self.position_encoding(f), f) for f in tracker_features] + return features, positions, tracker_features, tracker_positions diff --git a/comfy/ldm/sam3/tracker.py b/comfy/ldm/sam3/tracker.py new file mode 100644 index 000000000..8f7481003 --- /dev/null +++ b/comfy/ldm/sam3/tracker.py @@ -0,0 +1,1785 @@ +# SAM3 video tracker: memory encoder, memory attention, SAM mask decoder/prompt encoder. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +try: + import cv2 + _HAS_CV2 = True +except ImportError: + from scipy import ndimage + _HAS_CV2 = False + +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.sam3.sam import rope_2d, PositionEmbeddingSine +from comfy.ops import cast_to_input +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.cascade.common import LayerNorm2d_op +from comfy.ldm.sam3.sam import MLP, PositionEmbeddingRandom +from comfy.ldm.sam3.sam import TwoWayTransformer as SAMTwoWayTransformer + +NO_OBJ_SCORE = -1024.0 + + +def to_spatial(x, H, W): + """Reshape (B, H*W, C) → (B, C, H, W).""" + return x.view(x.shape[0], H, W, -1).permute(0, 3, 1, 2) + +class MultiplexState: + """Tracks object-to-slot assignments for multiplex tracking. Provides mux/demux operations.""" + + def __init__(self, num_objects, multiplex_count, device, dtype): + self.multiplex_count = multiplex_count + self.device = device + self.dtype = dtype + self._build(num_objects) + + def mux(self, x): + """[N_obj, ...] -> [num_buckets, multiplex_count, ...]""" + out_shape = (self.num_buckets, self.multiplex_count) + x.shape[1:] + return (self.mux_matrix.to(device=x.device, dtype=x.dtype) @ x.reshape(self.total_valid_entries, -1)).view(out_shape) + + def demux(self, x): + """[num_buckets, multiplex_count, ...] -> [N_obj, ...]""" + out_shape = (self.total_valid_entries,) + x.shape[2:] + flat = x.reshape(self.num_buckets * self.multiplex_count, -1) + return (self.demux_matrix.to(device=x.device, dtype=x.dtype) @ flat).view(out_shape) + + def get_valid_object_mask(self): + """[num_buckets, multiplex_count] bool tensor, True for valid slots.""" + return (self.mux_matrix.sum(dim=1) > 0).reshape(self.num_buckets, self.multiplex_count) + + def _build(self, num_objects): + M = self.multiplex_count + self.num_buckets = (num_objects + M - 1) // M + self.total_valid_entries = num_objects + total_slots = self.num_buckets * M + self.mux_matrix = torch.zeros(total_slots, num_objects, device=self.device, dtype=self.dtype) + self.demux_matrix = torch.zeros(num_objects, total_slots, device=self.device, dtype=self.dtype) + oids = torch.arange(num_objects, device=self.device) + slots = (oids // M) * M + (oids % M) + self.mux_matrix[slots, oids] = 1.0 + self.demux_matrix[oids, slots] = 1.0 + + def add_objects(self, n_new): + """Grow multiplex state for n_new additional objects.""" + self._build(self.total_valid_entries + n_new) + +def _compute_mask_overlap(masks_a, masks_b): + """Max of IoU and IoM (intersection over minimum area). More robust to size differences.""" + a_flat = (masks_a > 0).float().flatten(1) + b_flat = (masks_b > 0).float().flatten(1) + intersection = a_flat @ b_flat.T + area_a = a_flat.sum(1, keepdim=True) + area_b = b_flat.sum(1, keepdim=True).T + iou = intersection / (area_a + area_b - intersection).clamp(min=1) + iom = intersection / torch.min(area_a.expand_as(iou), area_b.expand_as(iou)).clamp(min=1) + return torch.max(iou, iom) + + +def _nms_masks(masks, scores, thresh=0.5): + """Mask-based NMS using IoU+IoM overlap. Returns (filtered_masks, filtered_scores).""" + order = scores.argsort(descending=True) + masks, scores = masks[order], scores[order] + keep = [] + for i in range(masks.shape[0]): + if keep: + if _compute_mask_overlap(masks[i:i+1], masks[torch.tensor(keep, device=masks.device)]).max() >= thresh: + continue + keep.append(i) + return masks[keep], scores[keep] + + +def _get_connected_components(mask_bin): + """Get connected component labels and areas. mask_bin: [B, 1, H, W] uint8.""" + labels_list, areas_list = [], [] + for i in range(mask_bin.shape[0]): + m = mask_bin[i, 0].cpu().numpy() + if _HAS_CV2: + _, labeled, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8) + areas = stats[labeled, cv2.CC_STAT_AREA].astype('int32') + else: + labeled, num_features = ndimage.label(m) + areas = np.zeros_like(m, dtype=np.int32) + for c in range(1, num_features + 1): + component = labeled == c + areas[component] = component.sum() + labels_list.append(torch.from_numpy(labeled).to(mask_bin.device)) + areas_list.append(torch.from_numpy(areas).to(device=mask_bin.device, dtype=torch.int32)) + return torch.stack(labels_list).unsqueeze(1), torch.stack(areas_list).unsqueeze(1) + + +def fill_holes_in_mask_scores(mask, max_area=0): + """Remove small foreground sprinkles and fill small background holes using connected components.""" + if max_area <= 0: + return mask + + # Fill holes: small connected components in background → foreground + mask_bg = (mask <= 0).to(torch.uint8) + _, areas_bg = _get_connected_components(mask_bg) + small_bg = mask_bg.bool() & (areas_bg <= max_area) + mask = torch.where(small_bg, 0.1, mask) + + # Remove sprinkles: small connected components in foreground → background + # Only remove if area < min(max_area, half of total foreground area) + mask_fg = (mask > 0).to(torch.uint8) + fg_area_thresh = mask_fg.sum(dim=(2, 3), keepdim=True, dtype=torch.int32) + fg_area_thresh.floor_divide_(2).clamp_(max=max_area) + _, areas_fg = _get_connected_components(mask_fg) + small_fg = mask_fg.bool() & (areas_fg <= fg_area_thresh) + mask = torch.where(small_fg, -0.1, mask) + + return mask + + +def apply_rope_memory(q, k, freqs, num_heads, num_k_exclude_rope=0): + """Apply 2D axial RoPE to memory attention using flux rope format. + + Args: + q: [B, Nq, C] projected queries (current frame features) + k: [B, Nk, C] projected keys (memory tokens) + freqs: [1, Nq, dim//2, 2, 2] flux-format rotation matrices for one frame + num_heads: number of attention heads + num_k_exclude_rope: number of trailing k tokens to skip RoPE (object pointers) + """ + B, Nq, C = q.shape + head_dim = C // num_heads + + # freqs shape: [1, 1, Nq, dim//2, 2, 2] (heads broadcast dim already included) + q_h = q.view(B, Nq, num_heads, head_dim).transpose(1, 2) + q_h = apply_rope1(q_h, freqs) + q = q_h.transpose(1, 2).reshape(B, Nq, C) + + # Apply RoPE to k (excluding last num_k_exclude_rope tokens) + Nk = k.shape[1] + num_k_rope = Nk - num_k_exclude_rope + if num_k_rope > 0: + # Repeat freqs for multiple frames of spatial memory + Nf = freqs.shape[2] # spatial positions in one frame + if num_k_rope > Nf: + r = (num_k_rope + Nf - 1) // Nf + pe_k = freqs.repeat(1, 1, r, 1, 1, 1)[:, :, :num_k_rope] + else: + pe_k = freqs[:, :, :num_k_rope] + + k_h = k[:, :num_k_rope].view(B, num_k_rope, num_heads, head_dim).transpose(1, 2) + k_h = apply_rope1(k_h, pe_k) + k = k.clone() + k[:, :num_k_rope] = k_h.transpose(1, 2).reshape(B, num_k_rope, C) + + return q, k + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """1D sinusoidal positional encoding for temporal positions.""" + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + pos_embed = pos_inds.unsqueeze(-1) / dim_t + return torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + + +def _pad_to_buckets(tensor, target_buckets): + """Pad a [num_buckets, ...] tensor to target_buckets along dim 0 if needed.""" + if tensor.shape[0] >= target_buckets: + return tensor + pad_shape = (target_buckets - tensor.shape[0],) + tensor.shape[1:] + return torch.cat([tensor, torch.zeros(pad_shape, device=tensor.device, dtype=tensor.dtype)], dim=0) + + +def pack_masks(masks): + """Pack binary masks [*, H, W] to bit-packed [*, H, W//8] uint8. W must be divisible by 8.""" + binary = masks > 0 + shifts = torch.arange(8, device=masks.device) + return (binary.view(*masks.shape[:-1], -1, 8) * (1 << shifts)).sum(-1).byte() + + +def unpack_masks(packed): + """Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8].""" + shifts = torch.arange(8, device=packed.device) + return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool() + + +def _compute_backbone(backbone_fn, frame, frame_idx=None): + """Compute backbone features for a single frame. Returns (vision_feats, vision_pos, feat_sizes, features, trunk_out).""" + features, positions, trunk_out = backbone_fn(frame, frame_idx=frame_idx) + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in features] + vision_feats = [x.flatten(2).permute(0, 2, 1) for x in features] + vision_pos = [x.flatten(2).permute(0, 2, 1) for x in positions] + return vision_feats, vision_pos, feat_sizes, features, trunk_out + + +def collect_memory_tokens(output_dict, frame_idx, num_maskmem, maskmem_tpos_enc, device, + collect_image_feats=False, tpos_v2=False, num_buckets=None): + """Collect spatial memory, position encodings, and optionally image features from past frames.""" + to_cat_memory, to_cat_memory_pos = [], [] + to_cat_image_feat, to_cat_image_pos = [], [] + + def _append(out, tpos_idx): + feats = out["maskmem_features"].to(device) + if num_buckets is not None: + feats = _pad_to_buckets(feats, num_buckets) + to_cat_memory.append(feats.flatten(2).permute(0, 2, 1)) + enc = out["maskmem_pos_enc"][-1].to(device).flatten(2).permute(0, 2, 1) + if num_buckets is not None: + enc = _pad_to_buckets(enc, num_buckets) + tpos = cast_to_input(maskmem_tpos_enc[tpos_idx], enc) + to_cat_memory_pos.append(enc + tpos) + if collect_image_feats and "image_features" in out: + to_cat_image_feat.append(out["image_features"].to(device)) + to_cat_image_pos.append(out["image_pos_enc"].to(device) + tpos) + + cond_outputs = output_dict["cond_frame_outputs"] + for t, out in cond_outputs.items(): + if tpos_v2: + t_pos = frame_idx - t + tpos_idx = num_maskmem - t_pos - 1 if 0 < t_pos < num_maskmem else num_maskmem - 1 + else: + tpos_idx = num_maskmem - 1 + _append(out, tpos_idx) + + for t_pos in range(1, num_maskmem): + out = output_dict["non_cond_frame_outputs"].get(frame_idx - (num_maskmem - t_pos), None) + if out is None or out.get("maskmem_features") is None: + continue + _append(out, num_maskmem - t_pos - 1) + + return to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs + + +def compute_tpos_enc(rel_pos_list, device, d_model, proj_layer, dtype=None, max_abs_pos=None): + """Temporal position encoding for object pointers.""" + pos_enc = torch.tensor(rel_pos_list, dtype=torch.float32, device=device) / max((max_abs_pos or 2) - 1, 1) + pos_enc = get_1d_sine_pe(pos_enc, dim=d_model) + if dtype is not None: + pos_enc = pos_enc.to(dtype) + return proj_layer(pos_enc) + + +def forward_sam_heads(backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + """Shared SAM prompt encoder + mask decoder forward for both SAM3 and SAM3.1 trackers.""" + device = backbone_features.device + # Batch size from inputs (mask_inputs may have N_obj > 1 while backbone is batch 1) + if mask_inputs is not None: + B = mask_inputs.shape[0] + elif box_inputs is not None: + B = box_inputs.shape[0] + elif point_inputs is not None: + B = point_inputs["point_coords"].shape[0] + else: + B = backbone_features.shape[0] + + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + else: + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + if mask_inputs is not None: + prompt_size = (prompt_encoder.image_embedding_size[0] * 4, prompt_encoder.image_embedding_size[1] * 4) + if mask_inputs.shape[-2:] != prompt_size: + sam_mask_prompt = F.interpolate(mask_inputs, size=prompt_size, mode="bilinear", align_corners=False, antialias=True) + else: + sam_mask_prompt = mask_inputs + else: + sam_mask_prompt = None + + sparse, dense = prompt_encoder(points=(sam_point_coords, sam_point_labels), boxes=box_inputs, masks=sam_mask_prompt) + sparse = cast_to_input(sparse, backbone_features) + dense = cast_to_input(dense, backbone_features) + image_pe = cast_to_input(prompt_encoder.get_dense_pe(), backbone_features) + + low_res_multimasks, ious, sam_output_tokens, object_score_logits = mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=sparse, dense_prompt_embeddings=dense, + high_res_features=high_res_features, multimask_output=multimask_output, return_all=True, + ) + + is_obj_appearing = object_score_logits > 0 + low_res_multimasks = torch.where(is_obj_appearing[:, None, None], low_res_multimasks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_multimasks.dtype)) + high_res_multimasks = F.interpolate(low_res_multimasks, size=(image_size, image_size), mode="bilinear", align_corners=False) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + obj_ptr = obj_ptr_proj(sam_output_token) + obj_ptr = no_obj_fn(obj_ptr, is_obj_appearing) + + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +def use_mask_as_output(backbone_features, high_res_features, mask_inputs, mask_downsample, + prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, image_size, backbone_stride): + """Shared mask-as-output for both SAM3 and SAM3.1 trackers.""" + out_scale, out_bias = 20.0, -10.0 + mask_inputs_float = cast_to_input(mask_inputs, backbone_features) + high_res_masks = mask_inputs_float * out_scale + out_bias + low_res_masks = F.interpolate(high_res_masks, size=(image_size // backbone_stride * 4,) * 2, + mode="bilinear", align_corners=False, antialias=True) + _, _, obj_ptr, _ = forward_sam_heads( + backbone_features, prompt_encoder, mask_decoder, obj_ptr_proj, no_obj_fn, + image_size, mask_inputs=mask_downsample(mask_inputs_float), high_res_features=high_res_features, + ) + is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1)[..., None] + alpha = is_obj_appearing.to(obj_ptr.dtype) + object_score_logits = out_scale * alpha + out_bias + return low_res_masks, high_res_masks, obj_ptr, object_score_logits + + +# Split attention with configurable input dims (for asymmetric cross-attention) +class SplitAttn(nn.Module): + def __init__(self, embed_dim, num_heads=1, kv_dim=None, internal_dim=None, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + kv_dim = kv_dim or embed_dim + internal_dim = internal_dim or embed_dim + self.q_proj = operations.Linear(embed_dim, internal_dim, device=device, dtype=dtype) + self.k_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.v_proj = operations.Linear(kv_dim, internal_dim, device=device, dtype=dtype) + self.out_proj = operations.Linear(internal_dim, embed_dim, device=device, dtype=dtype) + + def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0): + if k is None: + k = q + if v is None: + v = k + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False) + return self.out_proj(out) + + +class MemoryAttnLayer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.self_attn = SplitAttn(d_model, num_heads, device=device, dtype=dtype, operations=operations) + self.cross_attn_image = SplitAttn(d_model, num_heads, kv_dim=kv_dim, device=device, dtype=dtype, operations=operations) + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, x, memory, memory_pos=None, rope=None, num_k_exclude_rope=0): + x = x + self.self_attn(self.norm1(x), rope=rope) + mem_k = memory + memory_pos if memory_pos is not None else memory + x = x + self.cross_attn_image(self.norm2(x), mem_k, memory, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + normed = self.norm3(x) + x = x + self.linear2(F.relu(self.linear1(normed))) + return x + + +class MemoryAttnEncoder(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + MemoryAttnLayer(d_model, num_heads, kv_dim, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, src_pos=None, memory_pos=None, num_k_exclude_rope=0): + if src_pos is not None: + x = x + 0.1 * src_pos + + rope = self._rope.to(device=x.device) + for layer in self.layers: + x = layer(x, memory, memory_pos=memory_pos, rope=rope, num_k_exclude_rope=num_k_exclude_rope) + return self.norm(x) + + +class MemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, kv_dim=64, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = MemoryAttnEncoder(d_model, num_heads, kv_dim, dim_ff, num_layers, device=device, dtype=dtype, operations=operations) + + +def _upscale_masks(output_upscaling, conv_s0, conv_s1, src_out, high_res_features): + """Shared upscaling for SAM mask decoders: deconv + high-res feature integration.""" + dc1, ln1, act1, dc2, act2 = output_upscaling + if high_res_features is not None: + upscaled = act1(ln1(dc1(src_out) + conv_s1(high_res_features[1]))) + upscaled = act2(dc2(upscaled) + conv_s0(high_res_features[0])) + else: + upscaled = act2(dc2(act1(ln1(dc1(src_out))))) + return upscaled + + +class SAMMaskDecoder(nn.Module): + def __init__(self, d_model=256, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_mask_tokens = num_multimask_outputs + 1 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.iou_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(self.num_mask_tokens, d_model, device=device, dtype=dtype) + self.obj_score_token = operations.Embedding(1, d_model, device=device, dtype=dtype) + + # Output upscaling: d_model -> d_model//4 -> d_model//8 at 4x resolution + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + + # High-res feature integration + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Per-mask hypernetwork MLPs + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_tokens) + ]) + + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_tokens, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False): + B = sparse_prompt_embeddings.shape[0] + ref = sparse_prompt_embeddings + # Token order: [obj_score(1), iou(1), mask(num_mask_tokens)] + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), + cast_to_input(self.mask_tokens.weight, ref)], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + src_flat = src.flatten(2).permute(0, 2, 1) + pos_flat = pos_src.flatten(2).permute(0, 2, 1) + + hs, src_out = self.transformer(src_flat, pos_flat, tokens) + + obj_score_token_out = hs[:, 0, :] + iou_token_out = hs[:, 1, :] + mask_tokens_out = hs[:, 2:2 + self.num_mask_tokens, :] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + hyper_in = torch.stack([ + mlp(mask_tokens_out[:, i, :]) for i, mlp in enumerate(self.output_hypernetworks_mlps) + ], dim=1) + + masks = (hyper_in @ upscaled.flatten(2)).view(B, self.num_mask_tokens, upscaled.shape[2], upscaled.shape[3]) + iou_pred = self.iou_prediction_head(iou_token_out) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) + + if multimask_output: + out_masks = masks[:, 1:] + out_iou = iou_pred[:, 1:] + out_tokens = mask_tokens_out[:, 1:] + else: + out_masks = masks[:, 0:1] + out_iou = iou_pred[:, 0:1] + out_tokens = mask_tokens_out[:, 0:1] + + if return_all: + return out_masks, out_iou, out_tokens, object_score_logits + return out_masks, out_iou + + +class SAMPromptEncoder(nn.Module): + def __init__(self, d_model=256, image_embedding_size=(72, 72), input_image_size=(1008, 1008), device=None, dtype=None, operations=None): + super().__init__() + self.embed_dim = d_model + self.image_embedding_size = image_embedding_size + self.input_image_size = input_image_size + + self.pe_layer = PositionEmbeddingRandom(d_model // 2) + self.point_embeddings = nn.ModuleList([ + operations.Embedding(1, d_model, device=device, dtype=dtype) for _ in range(4) + ]) + self.not_a_point_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.mask_downscaling = nn.Sequential( + operations.Conv2d(1, 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(4, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(4, 16, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(16, device=device, dtype=dtype), nn.GELU(), + operations.Conv2d(16, d_model, kernel_size=1, device=device, dtype=dtype), + ) + self.no_mask_embed = operations.Embedding(1, d_model, device=device, dtype=dtype) + + def get_dense_pe(self): + return self.pe_layer(self.image_embedding_size) + + def forward(self, points=None, boxes=None, masks=None): + ref = points[0] if points is not None else boxes if boxes is not None else masks + B = 1 + sparse = torch.empty((B, 0, self.embed_dim), device=ref.device, dtype=ref.dtype) + + if points is not None: + coords, labels = points + B = coords.shape[0] + # Pad with an extra point (label=-1) when no boxes are provided (matching reference) + if boxes is None: + coords = torch.cat([coords, torch.zeros(B, 1, 2, device=coords.device, dtype=coords.dtype)], dim=1) + labels = torch.cat([labels, -torch.ones(B, 1, device=labels.device, dtype=labels.dtype)], dim=1) + pe = self.pe_layer.forward_with_coords(coords + 0.5, self.input_image_size) + for i in range(4): + pe[labels == i] += cast_to_input(self.point_embeddings[i].weight, ref) + invalid = (labels == -1) + pe[invalid] = 0.0 + pe[invalid] += cast_to_input(self.not_a_point_embed.weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), pe], dim=1) + + if boxes is not None: + B = boxes.shape[0] + corners = self.pe_layer.forward_with_coords((boxes.reshape(-1, 2, 2) + 0.5), self.input_image_size) + corners[:, 0] += cast_to_input(self.point_embeddings[2].weight, ref) + corners[:, 1] += cast_to_input(self.point_embeddings[3].weight, ref) + sparse = torch.cat([sparse.expand(B, -1, -1), corners], dim=1) + + if masks is not None: + dense = self.mask_downscaling(masks) + else: + dense = cast_to_input(self.no_mask_embed.weight, ref).reshape(1, -1, 1, 1).expand( + B, -1, self.image_embedding_size[0], self.image_embedding_size[1]) + + return sparse, dense + + +class CXBlock(nn.Module): + def __init__(self, dim=256, kernel_size=7, device=None, dtype=None, operations=None): + super().__init__() + self.dwconv = operations.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, device=device, dtype=dtype) + self.pwconv1 = operations.Linear(dim, 4 * dim, device=device, dtype=dtype) + self.pwconv2 = operations.Linear(4 * dim, dim, device=device, dtype=dtype) + self.gamma = nn.Parameter(torch.ones(dim, device=device, dtype=dtype)) + + def forward(self, x): + residual = x + x = self.dwconv(x).permute(0, 2, 3, 1) + x = self.pwconv2(F.gelu(self.pwconv1(self.norm(x)))) + x.mul_(cast_to_input(self.gamma, x)) + return residual + x.permute(0, 3, 1, 2) + + +class MaskDownSampler(nn.Module): + def __init__(self, out_dim=256, in_chans=1, channels=None, interpol_size=(1152, 1152), device=None, dtype=None, operations=None): + super().__init__() + self.interpol_size = list(interpol_size) if interpol_size else None + if channels is None: + channels = [4, 16, 64, out_dim] # SAM3 default + LN2d = LayerNorm2d_op(operations) + layers = [] + prev = in_chans + for ch in channels: + layers += [operations.Conv2d(prev, ch, kernel_size=3, stride=2, padding=1, device=device, dtype=dtype), + LN2d(ch, device=device, dtype=dtype), nn.GELU()] + prev = ch + layers.append(operations.Conv2d(prev, out_dim, kernel_size=1, device=device, dtype=dtype)) + self.encoder = nn.Sequential(*layers) + + def forward(self, x): + if self.interpol_size is not None and list(x.shape[-2:]) != self.interpol_size: + x = F.interpolate(x, size=self.interpol_size, mode="bilinear", align_corners=False, antialias=True) + return self.encoder(x) + + +class Fuser(nn.Module): + def __init__(self, dim=256, num_layers=2, device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.Sequential(*[CXBlock(dim, device=device, dtype=dtype, operations=operations) for _ in range(num_layers)]) + + def forward(self, x): + return self.layers(x) + + +# --- SAM3.1 Multiplex components --- + +class DecoupledMemoryAttnLayer(nn.Module): + """Decoupled cross-attention layer for SAM3.1: fuses image and memory projections.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, device=None, dtype=None, operations=None): + super().__init__() + self.num_heads = num_heads + # Self-attention projections (flat, not nested) + self.self_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.self_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Cross-attention projections + self.cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_v_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.cross_attn_out_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # Image cross-attention (q/k only, fused with cross_attn) + self.image_cross_attn_q_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.image_cross_attn_k_proj = operations.Linear(d_model, d_model, device=device, dtype=dtype) + # FFN + self.linear1 = operations.Linear(d_model, dim_ff, device=device, dtype=dtype) + self.linear2 = operations.Linear(dim_ff, d_model, device=device, dtype=dtype) + self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype) + self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype) + + def forward(self, image, x, memory_image, memory, memory_image_pos=None, + rope=None, num_k_exclude_rope=0): + # Self-attention with RoPE + normed = self.norm1(x) + q = self.self_attn_q_proj(normed) + k = self.self_attn_k_proj(normed) + v = self.self_attn_v_proj(normed) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, 0) + x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + # Decoupled cross-attention: fuse image and memory projections + normed = self.norm2(x) + q = self.image_cross_attn_q_proj(image) + self.cross_attn_q_proj(normed) + k = self.image_cross_attn_k_proj(memory_image) + self.cross_attn_k_proj(memory) + if memory_image_pos is not None: + k = k + memory_image_pos + v = self.cross_attn_v_proj(memory) + if rope is not None: + q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope) + x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)) + + # FFN + x = x + self.linear2(F.gelu(self.linear1(self.norm3(x)))) + return image, x + + +class DecoupledMemoryEncoder(nn.Module): + """Memory attention encoder for SAM3.1 with decoupled cross-attention.""" + + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, image_size=1008, patch_size=14, + device=None, dtype=None, operations=None): + super().__init__() + self.layers = nn.ModuleList([ + DecoupledMemoryAttnLayer(d_model, num_heads, dim_ff, device=device, dtype=dtype, operations=operations) + for _ in range(num_layers) + ]) + self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype) + hw = image_size // patch_size + self.register_buffer("_rope", rope_2d(hw, hw, d_model // num_heads), persistent=False) + + def forward(self, x, memory, memory_pos=None, src_pos=None, num_k_exclude_rope=0, + memory_image=None, memory_image_pos=None): + image = x # constant residual for decoupled cross-attention + output = x + if src_pos is not None: + output = output + 0.1 * src_pos + + B, _, C = x.shape + rope = self._rope.to(device=x.device) + + # memory_image: raw backbone features from past frames for decoupled cross-attention + if memory_image is None: + # Fallback: use spatial portion of memory (without obj pointers) + num_spatial = memory.shape[1] - num_k_exclude_rope + memory_image = memory[:, :num_spatial] + memory_image_pos = memory_pos[:, :num_spatial] if memory_pos is not None else None + # Pad memory_image to match memory length (zeros for obj pointer tokens) + if memory_image.shape[1] < memory.shape[1]: + pad_len = memory.shape[1] - memory_image.shape[1] + pad = torch.zeros(B, pad_len, C, device=memory.device, dtype=memory.dtype) + memory_image = torch.cat([memory_image, pad], dim=1) + if memory_image_pos is not None: + ptr_pos = memory_pos[:, -pad_len:] if memory_pos is not None else torch.zeros_like(pad) + memory_image_pos = torch.cat([memory_image_pos, ptr_pos], dim=1) + + for layer in self.layers: + image, output = layer(image, output, memory_image, memory, + memory_image_pos=memory_image_pos, rope=rope, + num_k_exclude_rope=num_k_exclude_rope) + + return self.norm(output) + + +class DecoupledMemoryTransformer(nn.Module): + def __init__(self, d_model=256, num_heads=1, dim_ff=2048, num_layers=4, device=None, dtype=None, operations=None): + super().__init__() + self.encoder = DecoupledMemoryEncoder(d_model, num_heads, dim_ff, num_layers, + device=device, dtype=dtype, operations=operations) + + +class MemoryBackbone(nn.Module): + """Memory encoder: downsamples mask, fuses with pixel features, optionally compresses.""" + + def __init__(self, d_model=256, out_dim=None, in_chans=1, channels=None, device=None, dtype=None, operations=None): + super().__init__() + self.mask_downsampler = MaskDownSampler(d_model, in_chans=in_chans, channels=channels, device=device, dtype=dtype, operations=operations) + self.pix_feat_proj = operations.Conv2d(d_model, d_model, kernel_size=1, device=device, dtype=dtype) + self.fuser = Fuser(d_model, num_layers=2, device=device, dtype=dtype, operations=operations) + self.has_out_proj = out_dim is not None and out_dim != d_model + if self.has_out_proj: + self.out_proj = operations.Conv2d(d_model, out_dim, kernel_size=1, device=device, dtype=dtype) + feat_dim = out_dim + else: + feat_dim = d_model + self.position_encoding = PositionEmbeddingSine(num_pos_feats=feat_dim, normalize=True) + + def forward(self, image_features, mask_for_mem, skip_mask_sigmoid=False): + if not skip_mask_sigmoid: + mask_for_mem = mask_for_mem.sigmoid() + mask_features = self.mask_downsampler(cast_to_input(mask_for_mem, image_features)) + if mask_features.shape[-2:] != image_features.shape[-2:]: + mask_features = F.interpolate(mask_features, size=image_features.shape[-2:], mode="bilinear", align_corners=False) + features = self.pix_feat_proj(image_features) + mask_features + features = self.fuser(features) + if self.has_out_proj: + features = self.out_proj(features) + pos = cast_to_input(self.position_encoding(features), features) + return {"vision_features": features, "vision_pos_enc": [pos]} + + +class MultiplexMaskDecoder(nn.Module): + """SAM mask decoder for SAM3.1 multiplex: predicts masks for num_multiplex objects simultaneously. + + Uses multimask_outputs_only=True: num_mask_output_per_object = num_multimask_outputs (no +1). + Hypernetwork MLPs are shared across multiplex objects. + Token order: [obj_score_token(M), iou_token(M), mask_tokens(M*T)]. + """ + + def __init__(self, d_model=256, num_multiplex=16, num_multimask_outputs=3, device=None, dtype=None, operations=None): + super().__init__() + self.num_multiplex = num_multiplex + self.num_mask_output_per_object = num_multimask_outputs # 3 (multimask_outputs_only) + total_mask_tokens = num_multiplex * self.num_mask_output_per_object # 48 + + self.transformer = SAMTwoWayTransformer(depth=2, embedding_dim=d_model, num_heads=8, mlp_dim=2048, device=device, dtype=dtype, operations=operations) + + self.obj_score_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.iou_token = operations.Embedding(num_multiplex, d_model, device=device, dtype=dtype) + self.mask_tokens = operations.Embedding(total_mask_tokens, d_model, device=device, dtype=dtype) + + LN2d = LayerNorm2d_op(operations) + self.output_upscaling = nn.Sequential( + operations.ConvTranspose2d(d_model, d_model // 4, kernel_size=2, stride=2, device=device, dtype=dtype), + LN2d(d_model // 4, device=device, dtype=dtype), nn.GELU(), + operations.ConvTranspose2d(d_model // 4, d_model // 8, kernel_size=2, stride=2, device=device, dtype=dtype), nn.GELU(), + ) + self.conv_s0 = operations.Conv2d(d_model, d_model // 8, kernel_size=1, device=device, dtype=dtype) + self.conv_s1 = operations.Conv2d(d_model, d_model // 4, kernel_size=1, device=device, dtype=dtype) + + # Shared across all multiplex objects (one per mask output) + self.output_hypernetworks_mlps = nn.ModuleList([ + MLP(d_model, d_model, d_model // 8, 3, device=device, dtype=dtype, operations=operations) + for _ in range(self.num_mask_output_per_object) + ]) + self.iou_prediction_head = MLP(d_model, d_model, self.num_mask_output_per_object, 3, device=device, dtype=dtype, operations=operations) + self.pred_obj_score_head = MLP(d_model, d_model, 1, 3, device=device, dtype=dtype, operations=operations) + + def forward(self, image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, + high_res_features=None, multimask_output=False, return_all=False, extra_per_object_embeddings=None): + B = sparse_prompt_embeddings.shape[0] + M = self.num_multiplex + T = self.num_mask_output_per_object + + # Token order: [obj_score(M), iou(M), mask(M*T)] + ref = sparse_prompt_embeddings + mask_tokens = cast_to_input(self.mask_tokens.weight, ref) + if extra_per_object_embeddings is not None: + mask_tokens = mask_tokens.view(1, M, T, -1).expand(B, -1, -1, -1) + extra_per_object_embeddings.unsqueeze(2) + mask_tokens = mask_tokens.flatten(1, 2) # [B, M*T, C] + other_tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref)], dim=0).unsqueeze(0).expand(B, -1, -1) + tokens = torch.cat([other_tokens, mask_tokens, sparse_prompt_embeddings], dim=1) + else: + tokens = torch.cat([cast_to_input(self.obj_score_token.weight, ref), + cast_to_input(self.iou_token.weight, ref), mask_tokens], dim=0) + tokens = torch.cat([tokens.unsqueeze(0).expand(B, -1, -1), sparse_prompt_embeddings], dim=1) + + src = image_embeddings + if src.shape[0] != B: + src = src.expand(B, -1, -1, -1) + src = src + dense_prompt_embeddings + pos_src = image_pe.expand(B, -1, -1, -1) + + b, c, h, w = src.shape + hs, src_out = self.transformer(src.flatten(2).permute(0, 2, 1), pos_src.flatten(2).permute(0, 2, 1), tokens) + + # Parse output tokens + obj_score_token_out = hs[:, :M] + iou_token_out = hs[:, M:2 * M] + mask_tokens_out = hs[:, 2 * M:2 * M + M * T] + + src_out = src_out.permute(0, 2, 1).view(b, c, h, w) + upscaled = _upscale_masks(self.output_upscaling, self.conv_s0, self.conv_s1, src_out, high_res_features) + + # Reshape mask tokens to [B, M, T, C] and apply shared hypernetwork MLPs per mask output index + mask_tokens_2d = mask_tokens_out.view(B, M, T, -1) + hyper_in = torch.stack([ + self.output_hypernetworks_mlps[i](mask_tokens_2d[:, :, i, :]) # [B, M, C//8] + for i in range(T) + ], dim=2) # [B, M, T, C//8] + + # Generate masks: [B, M*T, H*W] -> [B, M, T, H, W] + masks = torch.bmm(hyper_in.flatten(1, 2), upscaled.flatten(2)).view(b, M, T, upscaled.shape[2], upscaled.shape[3]) + + # IoU and object scores + iou_pred = self.iou_prediction_head(iou_token_out).view(b, M, T) + object_score_logits = self.pred_obj_score_head(obj_score_token_out) # [B, M, 1] + + # multimask_outputs_only: always output all T masks (no singlemask token) + sam_tokens_out = mask_tokens_2d[:, :, 0:1] # [B, M, 1, C] + + if return_all: + return masks, iou_pred, sam_tokens_out, object_score_logits + return masks, iou_pred + + +class SAM3Tracker(nn.Module): + def __init__(self, d_model=256, mem_dim=64, num_maskmem=7, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + + # Memory attention transformer + self.transformer = MemoryTransformer(d_model, num_heads=1, kv_dim=mem_dim, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + # SAM components + self.sam_mask_decoder = SAMMaskDecoder(d_model, device=device, dtype=dtype, operations=operations) + self.sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone + self.maskmem_backbone = MemoryBackbone(d_model, out_dim=mem_dim, device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + self.register_buffer("no_mem_pos_enc", torch.zeros(1, 1, d_model, device=device, dtype=dtype)) # checkpoint key, unused in forward + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(1, mem_dim, device=device, dtype=dtype)) + self.no_obj_ptr = nn.Parameter(torch.zeros(1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + + # Mask downsample: Conv2d stride 4 to reduce GT mask to SAM logit scale + self.mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Config + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 20.0 + self.sigmoid_bias_for_mem_enc = -10.0 + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(cast_to_input(self.no_obj_ptr, obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.mask_downsample, self.sam_prompt_encoder, self.sam_mask_decoder, + self.obj_ptr_proj, self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames): + """Fuse current frame features with memory from previous frames.""" + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + # First conditioning frame: no memory yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, _, _, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx: + pos_and_ptrs.append(((frame_idx - t), out["obj_ptr"].to(device))) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"].to(device))) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [B, N, C=256] + + # Temporal position encoding for pointers + obj_pos = compute_tpos_enc( + list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype + ) # [N, mem_dim=64] + obj_pos = obj_pos.unsqueeze(0).expand(B, -1, -1) # [B, N, 64] + + # Split each 256-dim pointer into 4 x 64-dim tokens + if self.mem_dim < C: + N = obj_ptrs.shape[1] + obj_ptrs = obj_ptrs.view(B, N, C // self.mem_dim, self.mem_dim) # [B, N, 4, 64] + obj_ptrs = obj_ptrs.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, C // self.mem_dim, -1) + obj_pos = obj_pos.reshape(B, N * (C // self.mem_dim), self.mem_dim) # [B, N*4, 64] + + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + # No memory available yet, add no_mem_embed + pix_feat = current_vision_feats[-1] + cast_to_input(self.no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + # Concatenate all memory and position encodings [B, total_mem, mem_dim=64] + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Run memory attention encoder + pix_feat = current_vision_feats[-1] # [B, HW, C] + src_pos = current_vision_pos_embeds[-1] # [B, HW, C] + + pix_feat_with_mem = self.transformer.encoder( + x=pix_feat, + memory=memory, + src_pos=src_pos, + memory_pos=memory_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False): + """Encode predicted mask into memory features.""" + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + maskmem_out = self.maskmem_backbone(pix_feat, mask_for_mem, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed for occluded objects + alpha = (object_score_logits > 0).to(maskmem_features.dtype)[..., None, None] + no_obj = cast_to_input(self.no_obj_embed_spatial, maskmem_features)[..., None, None].expand_as(maskmem_features) + return maskmem_features + (1 - alpha) * no_obj, maskmem_pos_enc + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, mask_inputs, output_dict, + num_frames, point_inputs=None): + """Track one frame: fuse with memory, predict mask, encode memory.""" + current_out = {} + + # High-res features for SAM head [stride-8, stride-4] + if len(current_vision_feats) > 1: + high_res_features = [ + x.view(x.shape[0], feat_sizes[i][0], feat_sizes[i][1], -1).permute(0, 3, 1, 2) + for i, x in enumerate(current_vision_feats[:-1]) + ] + else: + high_res_features = None + + # Top-level feature for memory + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use mask directly + pix_feat = to_spatial(current_vision_feats[-1], H, W) + sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) + else: + # Track frame: fuse with memory, then SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + output_dict=output_dict, + num_frames=num_frames, + ) + # Use multimask for point prompts on init frames (picks best of 3 candidates) + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Clean low-res masks: remove sprinkles and fill holes + low_res_masks = fill_holes_in_mask_scores(low_res_masks, max_area=200) + high_res_masks = F.interpolate(low_res_masks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + # Encode memory + if self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, + pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, _, _ = _compute_backbone(backbone_fn, frame, frame_idx) + # SAM3: drop last FPN level + return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1] + + def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None): + """Track one object, computing backbone per frame to save VRAM.""" + N = images.shape[0] + device, dt = images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame( + backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx) + mask_input = None + if frame_idx == 0: + mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > 0.5).to(dt) + + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=(frame_idx == 0), + current_vision_feats=vision_feats, current_vision_pos_embeds=vision_pos, + feat_sizes=feat_sizes, mask_inputs=mask_input, output_dict=output_dict, num_frames=N) + + if frame_idx == 0: + output_dict["cond_frame_outputs"][frame_idx] = current_out + else: + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + # Move masks to CPU immediately to free VRAM + all_masks.append(current_out["pred_masks_high_res"].to(comfy.model_management.intermediate_device())) + if pbar is not None: + pbar.update(1) + + return torch.cat(all_masks, dim=0) # [N, 1, H, W] + + def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs): + """Track one or more objects across video frames. + + Args: + backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame + images: [N, 3, 1008, 1008] video frames + initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object) + pbar: optional progress bar + + Returns: + [N, N_obj, image_size, image_size] predicted mask logits per frame per object + """ + N_obj = initial_masks.shape[0] + per_object = [] + for obj_idx in range(N_obj): + obj_masks = self._track_single_object( + backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar) + per_object.append(obj_masks) + + return torch.cat(per_object, dim=1) # [N, N_obj, H, W] + + +class SAM31Tracker(nn.Module): + """SAM3.1 multiplex tracker: decoupled memory attention, dual decoder, 16-object multiplex.""" + + def __init__(self, d_model=256, mem_dim=256, num_maskmem=7, num_multiplex=16, device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.d_model = d_model + self.mem_dim = mem_dim + self.num_maskmem = num_maskmem + self.num_multiplex = num_multiplex + self.image_size = 1008 + self.backbone_stride = 14 + self.max_obj_ptrs_in_encoder = 16 + self.sigmoid_scale_for_mem_enc = 2.0 + self.sigmoid_bias_for_mem_enc = -1.0 + + # Memory attention (decoupled cross-attention, 8 heads matching reference) + self.transformer = DecoupledMemoryTransformer(d_model, num_heads=8, dim_ff=2048, num_layers=4, + device=device, dtype=dtype, operations=operations) + + # Propagation decoder (multiplex: 16 objects, multimask_outputs_only) + self.sam_mask_decoder = MultiplexMaskDecoder(d_model, num_multiplex, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + # Interactive decoder (single object, same as SAM3) + self.interactive_sam_mask_decoder = SAMMaskDecoder(d_model, num_multimask_outputs=3, + device=device, dtype=dtype, operations=operations) + self.interactive_sam_prompt_encoder = SAMPromptEncoder(d_model, device=device, dtype=dtype, operations=operations) + + # Memory backbone (mem_dim=256, no out_proj compression) + self.maskmem_backbone = MemoryBackbone(d_model, in_chans=num_multiplex * 2, channels=[16, 64, 256, 1024], + device=device, dtype=dtype, operations=operations) + + # Standalone parameters + self.maskmem_tpos_enc = nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim, device=device, dtype=dtype)) + self.no_obj_embed_spatial = nn.Parameter(torch.zeros(num_multiplex, mem_dim, device=device, dtype=dtype)) + self.interactivity_no_mem_embed = nn.Parameter(torch.zeros(1, 1, d_model, device=device, dtype=dtype)) + + # Object pointer projection + self.obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + self.obj_ptr_tpos_proj = operations.Linear(d_model, mem_dim, device=device, dtype=dtype) + self.no_obj_ptr_linear = operations.Linear(d_model, d_model, device=device, dtype=dtype) + self.interactive_obj_ptr_proj = MLP(d_model, d_model, d_model, 3, device=device, dtype=dtype, operations=operations) + + # Interactive mask downsample + self.interactive_mask_downsample = operations.Conv2d(1, 1, kernel_size=4, stride=4, device=device, dtype=dtype) + + # Multiplex validity embeddings + self.output_valid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + self.output_invalid_embed = nn.Parameter(torch.zeros(num_multiplex, d_model, device=device, dtype=dtype)) + + # Position encoding for image (used by multiplex decoder) + self.image_pe_layer = PositionEmbeddingRandom(d_model // 2) + + def _no_obj_blend(self, obj_ptr, is_obj): + alpha = is_obj.to(obj_ptr.dtype) + return torch.lerp(self.no_obj_ptr_linear(obj_ptr), obj_ptr, alpha) + + def _forward_sam_heads(self, backbone_features, point_inputs=None, mask_inputs=None, box_inputs=None, + high_res_features=None, multimask_output=False): + return forward_sam_heads(backbone_features, self.interactive_sam_prompt_encoder, self.interactive_sam_mask_decoder, + self.interactive_obj_ptr_proj, self._no_obj_blend, self.image_size, + point_inputs, mask_inputs, box_inputs, high_res_features, multimask_output) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + return use_mask_as_output(backbone_features, high_res_features, mask_inputs, + self.interactive_mask_downsample, self.interactive_sam_prompt_encoder, + self.interactive_sam_mask_decoder, self.interactive_obj_ptr_proj, + self._no_obj_blend, self.image_size, self.backbone_stride) + + def _prepare_memory_conditioned_features(self, frame_idx, is_init_cond_frame, current_vision_feats, + current_vision_pos_embeds, feat_sizes, output_dict, num_frames, + multiplex_state=None): + B = current_vision_feats[-1].shape[0] + C = self.d_model + H, W = feat_sizes[-1] + device = current_vision_feats[-1].device + num_buc = multiplex_state.num_buckets if multiplex_state is not None else None + + if self.num_maskmem == 0: + return current_vision_feats[-1].permute(0, 2, 1).view(B, C, H, W) + + if is_init_cond_frame: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + to_cat_memory, to_cat_memory_pos, to_cat_image_feat, to_cat_image_pos, cond_outputs = collect_memory_tokens( + output_dict, frame_idx, self.num_maskmem, self.maskmem_tpos_enc, device, + collect_image_feats=True, tpos_v2=True, num_buckets=num_buc) + + max_obj_ptrs = min(num_frames, self.max_obj_ptrs_in_encoder) + pos_and_ptrs = [] + for t, out in cond_outputs.items(): + if t <= frame_idx and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append(((frame_idx - t), ptr)) + for t_diff in range(1, max_obj_ptrs): + t = frame_idx - t_diff + if t < 0: + break + out = output_dict["non_cond_frame_outputs"].get(t, None) + if out is not None and "obj_ptr" in out: + ptr = out["obj_ptr"].to(device) + if num_buc is not None: + ptr = _pad_to_buckets(ptr, num_buc) + pos_and_ptrs.append((t_diff, ptr)) + + num_obj_ptr_tokens = 0 + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + obj_ptrs = torch.stack(ptrs_list, dim=1) # [num_buckets, N, M, C] + B_ptr = obj_ptrs.shape[0] + N_ptrs = obj_ptrs.shape[1] + M = obj_ptrs.shape[2] + obj_ptrs = obj_ptrs.reshape(B_ptr, N_ptrs * M, -1) + obj_pos = compute_tpos_enc(list(pos_list), device, self.d_model, self.obj_ptr_tpos_proj, + max_abs_pos=max_obj_ptrs, dtype=current_vision_feats[-1].dtype) + obj_pos = obj_pos.unsqueeze(0).expand(B_ptr, -1, -1) + obj_pos = obj_pos.unsqueeze(2).expand(-1, -1, M, -1).reshape(B_ptr, N_ptrs * M, -1) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[1] + + if len(to_cat_memory) == 0: + pix_feat = current_vision_feats[-1] + cast_to_input(self.interactivity_no_mem_embed, current_vision_feats[-1]) + return to_spatial(pix_feat, H, W) + + memory = torch.cat(to_cat_memory, dim=1) + memory_pos = torch.cat(to_cat_memory_pos, dim=1) + + # Expand vision features to num_buckets if memory has more buckets than B + mem_B = memory.shape[0] + x = current_vision_feats[-1] + x_pos = current_vision_pos_embeds[-1] + if x.shape[0] < mem_B: + x = x.expand(mem_B, -1, -1) + x_pos = x_pos.expand(mem_B, -1, -1) + + if len(to_cat_image_feat) > 0: + # Decoupled cross-attention: separate image features from memory + memory_image = cast_to_input(torch.cat(to_cat_image_feat, dim=1), x) + memory_image_pos = cast_to_input(torch.cat(to_cat_image_pos, dim=1), x) + if memory_image.shape[0] < mem_B: + memory_image = memory_image.expand(mem_B, -1, -1) + memory_image_pos = memory_image_pos.expand(mem_B, -1, -1) + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=cast_to_input(memory, x), + memory_pos=cast_to_input(memory_pos, x), + src_pos=cast_to_input(x_pos, x), + num_k_exclude_rope=num_obj_ptr_tokens, + memory_image=memory_image, + memory_image_pos=memory_image_pos, + ) + else: + pix_feat_with_mem = self.transformer.encoder( + x=x, + memory=memory, + memory_pos=memory_pos, + src_pos=x_pos, + num_k_exclude_rope=num_obj_ptr_tokens, + ) + return to_spatial(pix_feat_with_mem, H, W) + + def _encode_new_memory(self, pix_feat, pred_masks_high_res, object_score_logits, is_mask_from_pts=False, + multiplex_state=None, is_conditioning=False, cond_obj_mask=None): + if is_mask_from_pts: + mask_for_mem = (pred_masks_high_res > 0).to(pix_feat.dtype) + else: + mask_for_mem = torch.sigmoid(pred_masks_high_res) + mask_for_mem.mul_(self.sigmoid_scale_for_mem_enc).add_(self.sigmoid_bias_for_mem_enc) + + # Mux masks: [N_obj, 1, H, W] -> [num_buckets, M, H, W] + mux_masks = multiplex_state.mux(mask_for_mem[:, 0]) + + # Conditioning channel: 1.0 = clean mask (trust it), 0.0 = propagation (noisy) + N_obj = mask_for_mem.shape[0] + cond_values = torch.full((N_obj,), 0.0, device=mask_for_mem.device, dtype=mask_for_mem.dtype) + if is_conditioning: + cond_values[:] = 1.0 + elif cond_obj_mask is not None: + cond_values[cond_obj_mask] = 1.0 + cond_spatial = cond_values.view(-1, 1, 1, 1).expand_as(mask_for_mem[:, 0:1, :, :]).squeeze(1) + mux_cond = multiplex_state.mux(cond_spatial) # [num_buckets, M, H, W] + mux_input = torch.cat([mux_masks, mux_cond], dim=1) # [num_buckets, 2*M, H, W] + + maskmem_out = self.maskmem_backbone(pix_feat, mux_input, skip_mask_sigmoid=True) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + # Add no_obj_embed_spatial for occluded objects + is_obj = (object_score_logits > 0).float() # [N_obj, 1] + mux_is_obj = multiplex_state.mux(is_obj) # [num_buckets, M, 1] + no_obj_embed = cast_to_input(self.no_obj_embed_spatial, maskmem_features) # [M, C] + no_obj_spatial = no_obj_embed.unsqueeze(0)[..., None, None] # [1, M, C, 1, 1] + # Expand and sum across multiplex slots weighted by (1 - is_obj) + alpha = mux_is_obj[..., None, None] # [num_buckets, M, 1, 1, 1] + per_slot_no_obj = ((1 - alpha) * no_obj_spatial).sum(dim=1) # [num_buckets, C, 1, 1] + maskmem_features = maskmem_features + per_slot_no_obj.expand_as(maskmem_features) + + return maskmem_features, maskmem_pos_enc + + def _forward_propagation(self, backbone_features, high_res_features=None, multiplex_state=None): + """Propagation path using the multiplex SAM decoder (no prompts).""" + B = backbone_features.shape[0] + device = backbone_features.device + + # Suppression embeddings from valid object mask + valid_mask = cast_to_input(multiplex_state.get_valid_object_mask().unsqueeze(-1).float(), backbone_features) + output_valid = cast_to_input(self.output_valid_embed, backbone_features).unsqueeze(0) + output_invalid = cast_to_input(self.output_invalid_embed, backbone_features).unsqueeze(0) + extra_embed = valid_mask * output_valid + (1 - valid_mask) * output_invalid + + image_pe = self.image_pe_layer((backbone_features.shape[-2], backbone_features.shape[-1]), device=backbone_features.device) + image_pe = cast_to_input(image_pe, backbone_features) + + masks, iou_pred, sam_tokens_out, object_score_logits = self.sam_mask_decoder( + image_embeddings=backbone_features, image_pe=image_pe, + sparse_prompt_embeddings=torch.empty(B, 0, self.d_model, device=device, dtype=backbone_features.dtype), + dense_prompt_embeddings=torch.zeros(B, self.d_model, *backbone_features.shape[-2:], device=device, dtype=backbone_features.dtype), + high_res_features=high_res_features, multimask_output=True, return_all=True, + extra_per_object_embeddings=extra_embed.expand(B, -1, -1), + ) + # masks: [B=num_buckets, M, T, H, W] + # Demux to per-object: [N_obj, T, H, W] + masks_obj = multiplex_state.demux(masks) + iou_obj = multiplex_state.demux(iou_pred) + score_obj = multiplex_state.demux(object_score_logits) + tokens_obj = multiplex_state.demux(sam_tokens_out) + + # Select best mask by IoU for each object + best_idx = torch.argmax(iou_obj, dim=-1) # [N_obj] + N_obj = masks_obj.shape[0] + obj_range = torch.arange(N_obj, device=device) + low_res_masks = masks_obj[obj_range, best_idx].unsqueeze(1) # [N_obj, 1, H, W] + # Suppress masks for objects with low confidence + is_obj = score_obj > 0 + low_res_masks = torch.where(is_obj[:, :, None, None], low_res_masks, + torch.tensor(NO_OBJ_SCORE, device=device, dtype=low_res_masks.dtype)) + high_res_masks = F.interpolate(low_res_masks.float(), size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + + # Object pointer: compute per-object, mux for storage as [num_buckets, M, C] + sam_token = tokens_obj[:, 0] # [N_obj, C] + obj_ptr = self.obj_ptr_proj(sam_token) + is_obj = (score_obj > 0).float() + no_obj = self.no_obj_ptr_linear(obj_ptr) + obj_ptr = is_obj * obj_ptr + (1 - is_obj) * no_obj + obj_ptr_muxed = multiplex_state.mux(obj_ptr) # [num_buckets, M, C] + + return low_res_masks, high_res_masks, obj_ptr_muxed, score_obj + + def track_step(self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, + feat_sizes, mask_inputs, output_dict, num_frames, point_inputs=None, + interactive_high_res=None, interactive_backbone=None, propagation_high_res=None, + multiplex_state=None, run_mem_encoder=True): + current_out = {} + H, W = feat_sizes[-1] + + if mask_inputs is not None: + # Conditioning frame: use interactive features if available, else propagation + if interactive_backbone is not None: + pix_feat = interactive_backbone + # Add no_mem_embed for interactive path + pix_flat = pix_feat.flatten(2) + bf = pix_flat.permute(0, 2, 1) + cast_to_input(self.interactivity_no_mem_embed, pix_flat) + pix_feat = to_spatial(bf, H, W) + hi_res = interactive_high_res + else: + # Fallback: interactive backbone not available (e.g. called outside track_video). + # Propagation features work but may produce lower-quality conditioning. + pix_feat = to_spatial(current_vision_feats[-1], H, W) + hi_res = propagation_high_res + sam_outputs = self._use_mask_as_output(pix_feat, hi_res, mask_inputs) + elif point_inputs is not None: + # Interactive path: use interactive SAM decoder + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + hi_res = interactive_high_res if interactive_high_res is not None else propagation_high_res + num_pts = point_inputs["point_labels"].size(1) + multimask_output = is_init_cond_frame and 0 < num_pts <= 1 + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, point_inputs=point_inputs, + high_res_features=hi_res, multimask_output=multimask_output, + ) + else: + # Propagation path: use multiplex SAM decoder with propagation features + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, output_dict=output_dict, num_frames=num_frames, + multiplex_state=multiplex_state, + ) + sam_outputs = self._forward_propagation(pix_feat_with_mem, propagation_high_res, + multiplex_state=multiplex_state) + + (low_res_masks, high_res_masks, obj_ptr, object_score_logits) = sam_outputs + + # Mux obj_ptr if it came from interactive path (shape [B, C]) vs propagation ([num_buckets, M, C]) + if multiplex_state is not None and obj_ptr.dim() == 2: + obj_ptr = multiplex_state.mux(obj_ptr) # [N_obj, C] -> [num_buckets, M, C] + + # Encode memory (can be deferred with run_mem_encoder=False) + if run_mem_encoder and self.num_maskmem > 0: + pix_feat = to_spatial(current_vision_feats[-1], H, W) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=high_res_masks, + object_score_logits=object_score_logits, + is_mask_from_pts=(point_inputs is not None), + multiplex_state=multiplex_state, + is_conditioning=(mask_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + # Store propagation image features for decoupled memory attention + current_out["image_features"] = current_vision_feats[-1] # [B, HW, C] + current_out["image_pos_enc"] = current_vision_pos_embeds[-1] # [B, HW, C] + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + current_out["object_score_logits"] = object_score_logits + + return current_out + + def _compute_backbone_frame(self, backbone_fn, frame, frame_idx=None): + vision_feats, vision_pos, feat_sizes, features, trunk_out = _compute_backbone(backbone_fn, frame, frame_idx) + return vision_feats, vision_pos, feat_sizes, list(features[:-1]), trunk_out + + @staticmethod + def _suppress_recently_occluded(low_res_masks, last_occluded, frame_idx, threshold=0.3): + """Suppress overlapping masks for objects that were most recently occluded. + Prevents corrupted masks from occluded objects from contaminating other objects.""" + N_obj = low_res_masks.shape[0] + if N_obj <= 1: + return low_res_masks + binary = low_res_masks[:, 0] > 0 # [N_obj, H, W] + iou = _compute_mask_overlap(low_res_masks[:, 0], low_res_masks[:, 0]) + overlapping = torch.triu(iou >= threshold, diagonal=1) # [N, N] upper triangle + last_occ_i = last_occluded.unsqueeze(1) # [N, 1] + last_occ_j = last_occluded.unsqueeze(0) # [1, N] + # Suppress the more recently occluded object in each overlapping pair + suppress_i = overlapping & (last_occ_i > last_occ_j) & (last_occ_j > -1) + suppress_j = overlapping & (last_occ_j > last_occ_i) & (last_occ_i > -1) + to_suppress = suppress_i.any(dim=1) | suppress_j.any(dim=0) + # Update last_occluded for occluded/suppressed objects + is_empty = ~binary.any(dim=(-1, -2)) + newly_occluded = is_empty | to_suppress + last_occluded[newly_occluded] = frame_idx + # Suppress masks + low_res_masks[to_suppress] = -10.0 + return low_res_masks + + def _deferred_memory_encode(self, current_out, N_obj, vision_feats, feat_sizes, mux_state, device, + cond_obj_mask=None): + """Deferred memory encoding for propagation frames. cond_obj_mask: per-object bool for conditioning.""" + low_res_masks = current_out["pred_masks"] # [N_obj, 1, H_low, W_low] + + if N_obj > 1: + lr = low_res_masks.squeeze(1) # [N_obj, H, W] + max_obj = torch.argmax(lr, dim=0, keepdim=True) + batch_inds = torch.arange(N_obj, device=device)[:, None, None] + pixel_nol = torch.where(max_obj == batch_inds, lr, torch.clamp(lr, max=-10.0)) + area_before = (lr > 0).sum(dim=(-1, -2)).float().clamp(min=1) + area_after = (pixel_nol > 0).sum(dim=(-1, -2)).float() + shrink_ok = (area_after / area_before) >= 0.3 + low_res_masks = torch.where( + shrink_ok[:, None, None, None].expand_as(low_res_masks), + low_res_masks, torch.clamp(low_res_masks, max=-10.0)) + + interpol_size = self.maskmem_backbone.mask_downsampler.interpol_size + mem_masks = F.interpolate(low_res_masks, size=interpol_size, + mode="bilinear", align_corners=False) + + obj_scores = torch.where( + (mem_masks > 0).any(dim=(-1, -2)), 10.0, -10.0) + + pix_feat = to_spatial(vision_feats[-1], feat_sizes[-1][0], feat_sizes[-1][1]) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + pix_feat=pix_feat, pred_masks_high_res=mem_masks, + object_score_logits=obj_scores, + multiplex_state=mux_state, cond_obj_mask=cond_obj_mask) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + + def _add_detected_objects(self, new_masks, mux_state, vision_feats, feat_sizes, current_out): + """Grow MultiplexState with new detections, merge masks, re-encode memory. Modifies current_out.""" + n_old = mux_state.total_valid_entries + mux_state.add_objects(new_masks.shape[0]) + N_obj = mux_state.total_valid_entries + # Stored memory with old bucket counts is padded at read time by _pad_to_buckets + for k in ("pred_masks", "pred_masks_high_res"): + det = F.interpolate(new_masks.unsqueeze(1), size=current_out[k].shape[-2:], + mode="bilinear", align_corners=False) + current_out[k] = torch.cat([current_out[k], det], dim=0) + if self.num_maskmem > 0: + # Mark new objects as conditioning (clean detection masks) so model trusts them + cond_mask = torch.zeros(N_obj, dtype=torch.bool, device=new_masks.device) + cond_mask[n_old:] = True + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, + mux_state, new_masks.device, cond_obj_mask=cond_mask) + + def _condition_with_masks(self, masks, frame_idx, vision_feats, vision_pos, feat_sizes, + high_res_prop, output_dict, N, mux_state, backbone_obj, frame, + trunk_out, threshold=0.5): + """Condition tracker with masks on a frame.""" + mask_input = F.interpolate(masks if masks.dim() == 4 else masks.unsqueeze(1), + size=(self.image_size, self.image_size), mode="bilinear", align_corners=False) + mask_input = (mask_input > threshold).to(masks.dtype) + hi_res = lo_feat = None + if backbone_obj is not None and backbone_obj.multiplex: + _, _, itf, _ = backbone_obj(frame, tracker_mode="interactive", cached_trunk=trunk_out, tracker_only=True) + hi_res, lo_feat = itf[:-1], itf[-1] + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=True, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=mask_input, + output_dict=output_dict, num_frames=N, interactive_high_res=hi_res, + interactive_backbone=lo_feat, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=True) + output_dict["cond_frame_outputs"][frame_idx] = current_out + return current_out + + def _match_and_add_detections(self, det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects=0, + keep_alive=None): + """Match detections against tracked masks, add new objects, recondition degraded tracks. + Updates keep_alive counters: +1 for matched tracks, -1 for unmatched.""" + N_obj = mux_state.total_valid_entries + if det_masks.shape[0] == 0: + if keep_alive is not None: + for i in range(N_obj): + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + return [] + + # Match at low-res (like reference) + trk_masks = current_out["pred_masks"][:, 0] # [N_obj, H_low, W_low] + det_resized = F.interpolate(det_masks.unsqueeze(1), size=trk_masks.shape[-2:], + mode="bilinear", align_corners=False)[:, 0] + overlap = _compute_mask_overlap(det_resized, trk_masks) + + # Update keep_alive and find matched tracks + matched = set() + if overlap.shape[1] > 0: + matched = set((overlap >= 0.5).any(dim=0).nonzero(as_tuple=True)[0].tolist()) + if keep_alive is not None: + for i in range(N_obj): + if i in matched: + keep_alive[i] = min(8, keep_alive.get(i, 0) + 1) + else: + keep_alive[i] = max(-4, keep_alive.get(i, 0) - 1) + + # Recondition: high-confidence detections (>=0.8) with high overlap refresh tracked masks + reconditioned = False + if det_scores is not None and overlap.shape[1] > 0: + HIGH_CONF = 0.8 + for det_idx in range(overlap.shape[0]): + if det_scores[det_idx] < HIGH_CONF: + continue + best_trk = overlap[det_idx].argmax().item() + if overlap[det_idx, best_trk] >= 0.5: + # Replace tracked mask with fresh detection mask + current_out["pred_masks"][best_trk] = det_resized[det_idx].unsqueeze(0) + det_hr = F.interpolate(det_masks[det_idx:det_idx+1].unsqueeze(1), + size=current_out["pred_masks_high_res"].shape[-2:], + mode="bilinear", align_corners=False) + current_out["pred_masks_high_res"][best_trk] = det_hr[0] + reconditioned = True + + # Re-encode memory if any tracks were reconditioned + if reconditioned and self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + + # Add new detections (not matching any track) + if max_objects > 0 and N_obj >= max_objects: + return [] + max_overlap = overlap.max(dim=1)[0] if overlap.shape[1] > 0 else torch.zeros(overlap.shape[0], device=device) + new_dets = max_overlap < 0.5 + if new_dets.any(): + if max_objects > 0: + slots = max_objects - N_obj + new_dets = new_dets & (torch.cumsum(new_dets.int(), 0) <= slots) + self._add_detected_objects(det_masks[new_dets], mux_state, + vision_feats, feat_sizes, current_out) + if keep_alive is not None: + for i in range(N_obj, mux_state.total_valid_entries): + keep_alive[i] = 1 + return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item() + return [] + + def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None, + new_det_thresh=0.5, max_objects=0, detect_interval=1, + backbone_obj=None, pbar=None): + """Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits.""" + N, device, dt = images.shape[0], images.device, images.dtype + output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}} + all_masks = [] + idev = comfy.model_management.intermediate_device() + mux_state = None + if initial_masks is not None: + mux_state = MultiplexState(initial_masks.shape[0], self.num_multiplex, device, dt) + obj_scores = [] # per-object detection score (1.0 for initial masks) + keep_alive = {} if detect_fn is not None else None + last_occluded = torch.empty(0, device=device, dtype=torch.long) # per-object last occluded frame + + # Prefetch next frame's backbone on a separate CUDA stream + prefetch = False + backbone_stream = None + if comfy.model_management.is_device_cuda(device): + try: + backbone_stream = torch.cuda.Stream(device=device) + prefetch = True + except RuntimeError: + pass + cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0) + + for frame_idx in tqdm(range(N), desc="tracking"): + vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb + + # Start next frame's backbone on separate stream (overlaps with current frame's work) + if prefetch and frame_idx + 1 < N: + backbone_stream.wait_stream(torch.cuda.current_stream(device)) + with torch.cuda.stream(backbone_stream): + next_bb = self._compute_backbone_frame( + backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + # Per-frame detection with NMS (skip if no detect_fn, or interval/max not met) + det_masks = torch.empty(0, device=device) + det_scores = None + run_det = (detect_fn is not None + and frame_idx % max(detect_interval, 1) == 0 + and not (max_objects > 0 and mux_state is not None + and mux_state.total_valid_entries >= max_objects)) + if run_det: + det_out = detect_fn(trunk_out) + scores = det_out["scores"][0].sigmoid() + keep = scores > new_det_thresh + det_masks, det_scores = det_out["masks"][0][keep], scores[keep] + if det_masks.shape[0] > 1: + det_masks, det_scores = _nms_masks(det_masks, det_scores) + + if frame_idx == 0 and initial_masks is not None: + current_out = self._condition_with_masks( + initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos, + feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = [1.0] * mux_state.total_valid_entries + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 8 + elif mux_state is None or mux_state.total_valid_entries == 0: + if det_masks.shape[0] > 0: + if max_objects > 0: + det_scores = det_scores[:max_objects] + det_masks = det_masks[:max_objects] + mux_state = MultiplexState(det_masks.shape[0], self.num_multiplex, device, dt) + current_out = self._condition_with_masks( + det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop, + output_dict, N, mux_state, backbone_obj, + images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0) + last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long) + obj_scores = det_scores[:mux_state.total_valid_entries].tolist() + if keep_alive is not None: + for i in range(mux_state.total_valid_entries): + keep_alive[i] = 1 + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + # Skip to backbone advance at end of loop + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + continue + else: + N_obj = mux_state.total_valid_entries + current_out = self.track_step( + frame_idx=frame_idx, is_init_cond_frame=False, current_vision_feats=vision_feats, + current_vision_pos_embeds=vision_pos, feat_sizes=feat_sizes, mask_inputs=None, + output_dict=output_dict, num_frames=N, propagation_high_res=high_res_prop, + multiplex_state=mux_state, run_mem_encoder=False) + current_out["pred_masks"] = fill_holes_in_mask_scores( + current_out["pred_masks"], max_area=16) + if last_occluded.shape[0] == N_obj and N_obj > 1: + self._suppress_recently_occluded( + current_out["pred_masks"], last_occluded, frame_idx) + if self.num_maskmem > 0: + self._deferred_memory_encode(current_out, N_obj, vision_feats, feat_sizes, mux_state, device) + output_dict["non_cond_frame_outputs"][frame_idx] = current_out + lookback = max(self.num_maskmem, self.max_obj_ptrs_in_encoder) + for old_idx in list(output_dict["non_cond_frame_outputs"]): + if old_idx < frame_idx - lookback: + del output_dict["non_cond_frame_outputs"][old_idx] + n_before = mux_state.total_valid_entries + new_obj_scores = self._match_and_add_detections(det_masks, det_scores, current_out, mux_state, + vision_feats, feat_sizes, device, max_objects, + keep_alive if run_det else None) + n_added = mux_state.total_valid_entries - n_before + if n_added > 0: + last_occluded = torch.cat([last_occluded, + torch.full((n_added,), -1, device=device, dtype=torch.long)]) + obj_scores.extend(new_obj_scores) + + masks_out = current_out["pred_masks_high_res"][:, 0] + if keep_alive is not None: + for i in range(masks_out.shape[0]): + if keep_alive.get(i, 0) <= 0: + masks_out[i] = NO_OBJ_SCORE + N_obj_now = mux_state.total_valid_entries if mux_state is not None else 0 + if N_obj_now > 0: + all_masks.append(pack_masks(masks_out).to(idev)) + else: + all_masks.append(None) + if pbar is not None: + pbar.update(1) + + # Next frame's backbone + if frame_idx + 1 < N: + if prefetch: + torch.cuda.current_stream(device).wait_stream(backbone_stream) + cur_bb = next_bb + else: + cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1) + + if not all_masks or all(m is None for m in all_masks): + return {"packed_masks": None, "n_frames": N, "scores": []} + + max_obj = max(m.shape[0] for m in all_masks if m is not None) + sample = next(m for m in all_masks if m is not None) + empty_packed = torch.zeros(max_obj, *sample.shape[1:], dtype=torch.uint8, device=sample.device) + for i, m in enumerate(all_masks): + if m is None: + all_masks[i] = empty_packed + elif m.shape[0] < max_obj: + pad = torch.zeros(max_obj - m.shape[0], *m.shape[1:], dtype=torch.uint8, device=m.device) + all_masks[i] = torch.cat([m, pad], dim=0) + return {"packed_masks": torch.stack(all_masks, dim=0), "n_frames": N, "scores": obj_scores} diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..787ea1145 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,6 +54,7 @@ import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model +import comfy.ldm.sam3.detector import comfy.model_management import comfy.patcher_extension @@ -578,8 +579,8 @@ class Stable_Zero123(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None, cc_projection_weight=None, cc_projection_bias=None): super().__init__(model_config, model_type, device=device) self.cc_projection = comfy.ops.manual_cast.Linear(cc_projection_weight.shape[1], cc_projection_weight.shape[0], dtype=self.get_dtype(), device=device) - self.cc_projection.weight.copy_(cc_projection_weight) - self.cc_projection.bias.copy_(cc_projection_bias) + self.cc_projection.weight = torch.nn.Parameter(cc_projection_weight.clone()) + self.cc_projection.bias = torch.nn.Parameter(cc_projection_bias.clone()) def extra_conds(self, **kwargs): out = {} @@ -1974,3 +1975,7 @@ class ErnieImage(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class SAM3(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..724a241bf 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "ernie" return dit_config + if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1 + if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys: + dit_config = {} + dit_config["image_model"] = "SAM3" + if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys: + dit_config["image_model"] = "SAM31" + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal return model_config def unet_prefix_from_state_dict(state_dict): + # SAM3: detector.* and tracker.* at top level, no common prefix + if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict): + return "" + candidates = ["model.diffusion_model.", #ldm/sgm models "model.model.", #audio models "net.", #cosmos diff --git a/comfy/model_management.py b/comfy/model_management.py index bcf1399c4..3b39d6080 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1801,7 +1801,7 @@ def debug_memory_summary(): return torch.cuda.memory.memory_summary() return "" -class InterruptProcessingException(Exception): +class InterruptProcessingException(BaseException): pass interrupt_processing_mutex = threading.RLock() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 93d19d6fe..ee56f8523 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -685,9 +685,9 @@ class ModelPatcher: sd.pop(k) return sd - def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False): + def patch_weight_to_device(self, key, device_to=None, inplace_update=False, return_weight=False, force_cast=False): weight, set_func, convert_func = get_key_weight(self.model, key) - if key not in self.patches: + if key not in self.patches and not force_cast: return weight inplace_update = self.weight_inplace_update or inplace_update @@ -695,7 +695,7 @@ class ModelPatcher: if key not in self.backup and not return_weight: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) - temp_dtype = comfy.model_management.lora_compute_dtype(device_to) + temp_dtype = comfy.model_management.lora_compute_dtype(device_to) if key in self.patches else None if device_to is not None: temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: @@ -703,9 +703,10 @@ class ModelPatcher: if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) - out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) + out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key) if key in self.patches else temp_weight if set_func is None: - out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) + if key in self.patches: + out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype, seed=comfy.utils.string_to_seed(key)) if return_weight: return out_weight elif inplace_update: @@ -1584,7 +1585,7 @@ class ModelPatcherDynamic(ModelPatcher): key = key_param_name_to_key(n, param_key) if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) - self.patch_weight_to_device(key, device_to=device_to) + self.patch_weight_to_device(key, device_to=device_to, force_cast=True) weight, _, _ = get_key_weight(self.model, key) if weight is not None: self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() @@ -1609,6 +1610,10 @@ class ModelPatcherDynamic(ModelPatcher): m._v = vbar.alloc(v_weight_size) allocated_size += v_weight_size + for param in params: + if param not in ("weight", "bias"): + force_load_param(self, param, device_to) + else: for param in params: key = key_param_name_to_key(n, param) diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..736fe35de 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -12,6 +12,7 @@ from .ldm.cascade.stage_c_coder import StageC_coder from .ldm.audio.autoencoder import AudioOobleckVAE import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder +import comfy.ldm.lightricks.vae.audio_vae import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 @@ -805,6 +806,24 @@ class VAE: self.downscale_index_formula = (4, 8, 8) self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + elif "vocoder.resblocks.0.convs1.0.weight" in sd or "vocoder.vocoder.resblocks.0.convs1.0.weight" in sd: # LTX Audio + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder."}) + self.first_stage_model = comfy.ldm.lightricks.vae.audio_vae.AudioVAE(metadata=metadata) + self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype) + self.latent_channels = self.first_stage_model.latent_channels + self.audio_sample_rate_output = self.first_stage_model.output_sample_rate + self.autoencoder = self.first_stage_model.autoencoder # TODO: remove hack for ltxv custom nodes + self.output_channels = 2 + self.pad_channel_value = "replicate" + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 2 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.float32] + self.disable_offload = True + self.extra_1d_channel = 16 else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..8886f32d5 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1781,6 +1781,57 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class SAM3(supported_models_base.BASE): + unet_config = {"image_model": "SAM3"} + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + text_encoder_key_prefix = ["detector.backbone.language_backbone."] + unet_extra_prefix = "" + + def process_clip_state_dict(self, state_dict): + clip_keys = getattr(self, "_clip_stash", {}) + clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True) + clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.") + return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")} + + def process_unet_state_dict(self, state_dict): + self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k} + # SAM3.1: remap tracker.model.* -> tracker.* + for k in list(state_dict.keys()): + if k.startswith("tracker.model."): + state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k) + # SAM3.1: remove per-block freqs_cis buffers (computed dynamically) + for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]: + state_dict.pop(k) + # Split fused QKV projections + for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]: + t = state_dict.pop(k) + base, suffix = k.rsplit(".in_proj_", 1) + s = ".weight" if suffix == "weight" else ".bias" + d = t.shape[0] // 3 + state_dict[base + ".q_proj" + s] = t[:d] + state_dict[base + ".k_proj" + s] = t[d:2*d] + state_dict[base + ".v_proj" + s] = t[2*d:] + # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer + for k in list(state_dict.keys()): + if "sam_mask_decoder.transformer." not in k: + continue + new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.") + if new_k != k: + state_dict[new_k] = state_dict.pop(k) + return state_dict + + def get_model(self, state_dict, prefix="", device=None): + return model_base.SAM3(self, device=device) + + def clip_target(self, state_dict={}): + import comfy.text_encoders.sam3_clip + return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper) + + +class SAM31(SAM3): + unet_config = {"image_model": "SAM31"} + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31] models += [SVD_img2vid] diff --git a/comfy/text_encoders/sam3_clip.py b/comfy/text_encoders/sam3_clip.py new file mode 100644 index 000000000..11cb7d9db --- /dev/null +++ b/comfy/text_encoders/sam3_clip.py @@ -0,0 +1,97 @@ +import re +from comfy import sd1_clip + +SAM3_CLIP_CONFIG = { + "architectures": ["CLIPTextModel"], + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "intermediate_size": 4096, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "max_position_embeddings": 32, + "projection_dim": 512, + "vocab_size": 49408, + "layer_norm_eps": 1e-5, + "eos_token_id": 49407, +} + + +class SAM3ClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options) + + +class SAM3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data) + self.disable_weights = True + + +def _parse_prompts(text): + """Split comma-separated prompts with optional :N max detections per category""" + text = text.replace("(", "").replace(")", "") + parts = [p.strip() for p in text.split(",") if p.strip()] + result = [] + for part in parts: + m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part) + if m: + text_part = m.group(1).strip() + val = m.group(2) + max_det = max(1, round(float(val))) + result.append((text_part, max_det)) + else: + result.append((part, 1)) + return result + + +class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip") + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + parsed = _parse_prompts(text) + if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1): + return super().tokenize_with_weights(text, return_word_ids, **kwargs) + # Tokenize each prompt part separately, store per-part batches and metadata + inner = getattr(self, self.clip) + per_prompt = [] + for prompt_text, max_det in parsed: + batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs) + per_prompt.append((batches, max_det)) + # Main output uses first prompt's tokens (for compatibility) + out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt} + return out + + +class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip") + + def encode_token_weights(self, token_weight_pairs): + per_prompt = token_weight_pairs.pop("sam3_per_prompt", None) + if per_prompt is None: + return super().encode_token_weights(token_weight_pairs) + + # Encode each prompt separately, pack into extra dict + inner = getattr(self, self.clip) + multi_cond = [] + first_pooled = None + for batches, max_det in per_prompt: + out = inner.encode_token_weights(batches) + cond, pooled = out[0], out[1] + extra = out[2] if len(out) > 2 else {} + if first_pooled is None: + first_pooled = pooled + multi_cond.append({ + "cond": cond, + "attention_mask": extra.get("attention_mask"), + "max_detections": max_det, + }) + + # Return first prompt as main (for non-SAM3 consumers), all prompts in metadata + main = multi_cond[0] + main_extra = {} + if main["attention_mask"] is not None: + main_extra["attention_mask"] = main["attention_mask"] + main_extra["sam3_multi_cond"] = multi_cond + return (main["cond"], first_pooled, main_extra) diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py index 16d4acfd1..dc33533cc 100644 --- a/comfy_api/input/__init__.py +++ b/comfy_api/input/__init__.py @@ -9,6 +9,7 @@ from comfy_api.latest._input import ( CurveInput, MonotoneCubicCurve, LinearCurve, + RangeInput, ) __all__ = [ @@ -21,4 +22,5 @@ __all__ = [ "CurveInput", "MonotoneCubicCurve", "LinearCurve", + "RangeInput", ] diff --git a/comfy_api/latest/_input/__init__.py b/comfy_api/latest/_input/__init__.py index 05cd3d40a..f0229717e 100644 --- a/comfy_api/latest/_input/__init__.py +++ b/comfy_api/latest/_input/__init__.py @@ -1,5 +1,6 @@ from .basic_types import ImageInput, AudioInput, MaskInput, LatentInput from .curve_types import CurvePoint, CurveInput, MonotoneCubicCurve, LinearCurve +from .range_types import RangeInput from .video_types import VideoInput __all__ = [ @@ -12,4 +13,5 @@ __all__ = [ "CurveInput", "MonotoneCubicCurve", "LinearCurve", + "RangeInput", ] diff --git a/comfy_api/latest/_input/range_types.py b/comfy_api/latest/_input/range_types.py new file mode 100644 index 000000000..f4c5cb290 --- /dev/null +++ b/comfy_api/latest/_input/range_types.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import logging +import math +import numpy as np + +logger = logging.getLogger(__name__) + + +class RangeInput: + """Represents a levels/range adjustment: input range [min, max] with + optional midpoint (gamma control). + + Generates a 1D LUT identical to GIMP's levels mapping: + 1. Normalize input to [0, 1] using [min, max] + 2. Apply gamma correction: pow(value, 1/gamma) + 3. Clamp to [0, 1] + + The midpoint field is a position in [0, 1] representing where the + midtone falls within [min, max]. It maps to gamma via: + gamma = -log2(midpoint) + So midpoint=0.5 → gamma=1.0 (linear). + """ + + def __init__(self, min_val: float, max_val: float, midpoint: float | None = None): + self.min_val = min_val + self.max_val = max_val + self.midpoint = midpoint + + @staticmethod + def from_raw(data) -> RangeInput: + if isinstance(data, RangeInput): + return data + if isinstance(data, dict): + return RangeInput( + min_val=float(data.get("min", 0.0)), + max_val=float(data.get("max", 1.0)), + midpoint=float(data["midpoint"]) if data.get("midpoint") is not None else None, + ) + raise TypeError(f"Cannot convert {type(data)} to RangeInput") + + def to_lut(self, size: int = 256) -> np.ndarray: + """Generate a float64 lookup table mapping [0, 1] input through this + levels adjustment. + + The LUT maps normalized input values (0..1) to output values (0..1), + matching the GIMP levels formula. + """ + xs = np.linspace(0.0, 1.0, size, dtype=np.float64) + + in_range = self.max_val - self.min_val + if abs(in_range) < 1e-10: + return np.where(xs >= self.min_val, 1.0, 0.0).astype(np.float64) + + # Normalize: map [min, max] → [0, 1] + result = (xs - self.min_val) / in_range + result = np.clip(result, 0.0, 1.0) + + # Gamma correction from midpoint + if self.midpoint is not None and self.midpoint > 0 and self.midpoint != 0.5: + gamma = max(-math.log2(self.midpoint), 0.001) + inv_gamma = 1.0 / gamma + mask = result > 0 + result[mask] = np.power(result[mask], inv_gamma) + + return result + + def __repr__(self) -> str: + mid = f", midpoint={self.midpoint}" if self.midpoint is not None else "" + return f"RangeInput(min={self.min_val}, max={self.max_val}{mid})" diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 1b4993aa7..eb4d3701d 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -240,19 +240,34 @@ class VideoFromFile(VideoInput): start_time = self.__start_time # Get video frames frames = [] + alphas = None start_pts = int(start_time / video_stream.time_base) end_pts = int((start_time + self.__duration) / video_stream.time_base) container.seek(start_pts, stream=video_stream) + image_format = 'gbrpf32le' for frame in container.decode(video_stream): + if alphas is None: + for comp in frame.format.components: + if comp.is_alpha: + alphas = [] + image_format = 'gbrapf32le' + break + if frame.pts < start_pts: continue if self.__duration and frame.pts >= end_pts: break - img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3) - img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3) - frames.append(img) - images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0) + img = frame.to_ndarray(format=image_format) # shape: (H, W, 4) + if alphas is None: + frames.append(torch.from_numpy(img)) + else: + frames.append(torch.from_numpy(img[..., :-1])) + alphas.append(torch.from_numpy(img[..., -1:])) + + images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 0, 0, 3) + if alphas is not None: + alphas = torch.stack(alphas) if len(alphas) > 0 else torch.zeros(0, 0, 0, 1) # Get frame rate frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1) @@ -295,7 +310,7 @@ class VideoFromFile(VideoInput): }) metadata = container.metadata - return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata) + return VideoComponents(images=images, alpha=alphas, audio=audio, frame_rate=frame_rate, metadata=metadata) def get_components(self) -> VideoComponents: if isinstance(self.__file, io.BytesIO): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index fdeffea2d..4942ed46c 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1266,6 +1266,43 @@ class Histogram(ComfyTypeIO): Type = list[int] +@comfytype(io_type="RANGE") +class Range(ComfyTypeIO): + from comfy_api.input import RangeInput + if TYPE_CHECKING: + Type = RangeInput + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: dict=None, + display: str=None, + gradient_stops: list=None, + show_midpoint: bool=None, + midpoint_scale: str=None, + value_min: float=None, + value_max: float=None, + advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = {"min": 0.0, "max": 1.0} + self.display = display + self.gradient_stops = gradient_stops + self.show_midpoint = show_midpoint + self.midpoint_scale = midpoint_scale + self.value_min = value_min + self.value_max = value_max + + def as_dict(self): + return super().as_dict() | prune_dict({ + "display": self.display, + "gradient_stops": self.gradient_stops, + "show_midpoint": self.show_midpoint, + "midpoint_scale": self.midpoint_scale, + "value_min": self.value_min, + "value_max": self.value_max, + }) + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2276,5 +2313,6 @@ __all__ = [ "BoundingBox", "Curve", "Histogram", + "Range", "NodeReplace", ] diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index fd3b5a510..c92477f08 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from .._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput, MaskInput class VideoCodec(str, Enum): AUTO = "auto" @@ -48,5 +48,4 @@ class VideoComponents: frame_rate: Fraction audio: Optional[AudioInput] = None metadata: Optional[dict] = None - - + alpha: Optional[MaskInput] = None diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 3755323ac..eafabbefe 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -122,6 +122,41 @@ class TaskStatusResponse(BaseModel): usage: TaskStatusUsage | None = Field(None) +class GetAssetResponse(BaseModel): + id: str = Field(...) + name: str | None = Field(None) + url: str | None = Field(None) + asset_type: str = Field(...) + group_id: str = Field(...) + status: str = Field(...) + error: TaskStatusError | None = Field(None) + + +class SeedanceCreateVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + h5_link: str = Field(...) + + +class SeedanceGetVisualValidateSessionResponse(BaseModel): + session_id: str = Field(...) + status: str = Field(...) + group_id: str | None = Field(None) + error_code: str | None = Field(None) + error_message: str | None = Field(None) + + +class SeedanceCreateAssetRequest(BaseModel): + group_id: str = Field(...) + url: str = Field(...) + asset_type: str = Field(...) + name: str | None = Field(None, max_length=64) + project_name: str | None = Field(None) + + +class SeedanceCreateAssetResponse(BaseModel): + asset_id: str = Field(...) + + # Dollars per 1K tokens, keyed by (model_id, has_video_input). SEEDANCE2_PRICE_PER_1K_TOKENS = { ("dreamina-seedance-2-0-260128", False): 0.007, @@ -158,10 +193,17 @@ RECOMMENDED_PRESETS_SEEDREAM_4 = [ ("Custom", None, None), ] -# Seedance 2.0 reference video pixel count limits per model. +# Seedance 2.0 reference video pixel count limits per model and output resolution. SEEDANCE2_REF_VIDEO_PIXEL_LIMITS = { - "dreamina-seedance-2-0-260128": {"min": 409_600, "max": 927_408}, - "dreamina-seedance-2-0-fast-260128": {"min": 409_600, "max": 927_408}, + "dreamina-seedance-2-0-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + "1080p": {"min": 409_600, "max": 2_073_600}, + }, + "dreamina-seedance-2-0-fast-260128": { + "480p": {"min": 409_600, "max": 927_408}, + "720p": {"min": 409_600, "max": 927_408}, + }, } # The time in this dictionary are given for 10 seconds duration. diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 429c32444..de192c5ac 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,5 +1,6 @@ import logging import math +import re import torch from typing_extensions import override @@ -11,9 +12,14 @@ from comfy_api_nodes.apis.bytedance import ( SEEDANCE2_PRICE_PER_1K_TOKENS, SEEDANCE2_REF_VIDEO_PIXEL_LIMITS, VIDEO_TASKS_EXECUTION_TIME, + GetAssetResponse, Image2VideoTaskCreationRequest, ImageTaskCreationResponse, Seedance2TaskCreationRequest, + SeedanceCreateAssetRequest, + SeedanceCreateAssetResponse, + SeedanceCreateVisualValidateSessionResponse, + SeedanceGetVisualValidateSessionResponse, Seedream4Options, Seedream4TaskCreationRequest, TaskAudioContent, @@ -35,6 +41,7 @@ from comfy_api_nodes.util import ( get_number_of_images, image_tensor_pair_to_batch, poll_op, + resize_video_to_pixel_budget, sync_op, upload_audio_to_comfyapi, upload_image_to_comfyapi, @@ -43,10 +50,16 @@ from comfy_api_nodes.util import ( validate_image_aspect_ratio, validate_image_dimensions, validate_string, + validate_video_dimensions, + validate_video_duration, ) +from server import PromptServer BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" +_VERIFICATION_POLL_TIMEOUT_SEC = 120 +_VERIFICATION_POLL_INTERVAL_SEC = 3 + SEEDREAM_MODELS = { "seedream 5.0 lite": "seedream-5-0-260128", "seedream-4-5-251128": "seedream-4-5-251128", @@ -69,9 +82,12 @@ DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-2504 logger = logging.getLogger(__name__) -def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> None: - """Validate reference video pixel count against Seedance 2.0 model limits.""" - limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) +def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: str, index: int) -> None: + """Validate reference video pixel count against Seedance 2.0 model limits for the selected resolution.""" + model_limits = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id) + if not model_limits: + return + limits = model_limits.get(resolution) if not limits: return try: @@ -92,6 +108,169 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, index: int) -> ) +async def _resolve_reference_assets( + cls: type[IO.ComfyNode], + asset_ids: list[str], +) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: + """Look up each asset, validate Active status, group by asset_type. + + Returns (image_assets, video_assets, audio_assets), each mapping asset_id -> "asset://". + """ + image_assets: dict[str, str] = {} + video_assets: dict[str, str] = {} + audio_assets: dict[str, str] = {} + for i, raw_id in enumerate(asset_ids, 1): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + result = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + ) + if result.status != "Active": + extra = f" {result.error.code}: {result.error.message}" if result.error else "" + raise ValueError(f"Reference asset {i} (Id={asset_id}) is not Active (Status={result.status}).{extra}") + asset_uri = f"asset://{asset_id}" + if result.asset_type == "Image": + image_assets[asset_id] = asset_uri + elif result.asset_type == "Video": + video_assets[asset_id] = asset_uri + elif result.asset_type == "Audio": + audio_assets[asset_id] = asset_uri + return image_assets, video_assets, audio_assets + + +_ASSET_REF_RE = re.compile(r"\basset ?(\d{1,2})\b", re.IGNORECASE) + + +def _build_asset_labels( + reference_assets: dict[str, str], + image_asset_uris: dict[str, str], + video_asset_uris: dict[str, str], + audio_asset_uris: dict[str, str], + n_reference_images: int, + n_reference_videos: int, + n_reference_audios: int, +) -> dict[int, str]: + """Map asset slot number (from 'asset_N' keys) to its positional label. + + Asset entries are appended to `content` after the reference_images/videos/audios, + so their 1-indexed labels continue from the count of existing same-type refs: + one reference_images entry + one Image-type asset -> asset labelled "Image 2". + """ + image_n = n_reference_images + video_n = n_reference_videos + audio_n = n_reference_audios + labels: dict[int, str] = {} + for slot_key, raw_id in reference_assets.items(): + asset_id = (raw_id or "").strip() + if not asset_id: + continue + try: + slot_num = int(slot_key.rsplit("_", 1)[-1]) + except ValueError: + continue + if asset_id in image_asset_uris: + image_n += 1 + labels[slot_num] = f"Image {image_n}" + elif asset_id in video_asset_uris: + video_n += 1 + labels[slot_num] = f"Video {video_n}" + elif asset_id in audio_asset_uris: + audio_n += 1 + labels[slot_num] = f"Audio {audio_n}" + return labels + + +def _rewrite_asset_refs(prompt: str, labels: dict[int, str]) -> str: + """Case-insensitively replace 'assetNN' (1-2 digit) tokens with their labels.""" + if not labels: + return prompt + + def _sub(m: "re.Match[str]") -> str: + return labels.get(int(m.group(1)), m.group(0)) + + return _ASSET_REF_RE.sub(_sub, prompt) + + +async def _obtain_group_id_via_h5_auth(cls: type[IO.ComfyNode]) -> str: + session = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/visual-validate/sessions", method="POST"), + response_model=SeedanceCreateVisualValidateSessionResponse, + ) + logger.warning("Seedance authentication required. Open link: %s", session.h5_link) + + h5_text = f"Open this link in your browser and complete face verification:\n\n{session.h5_link}" + + result = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/visual-validate/sessions/{session.session_id}"), + response_model=SeedanceGetVisualValidateSessionResponse, + status_extractor=lambda r: r.status, + completed_statuses=["completed"], + failed_statuses=["failed"], + poll_interval=_VERIFICATION_POLL_INTERVAL_SEC, + max_poll_attempts=(_VERIFICATION_POLL_TIMEOUT_SEC // _VERIFICATION_POLL_INTERVAL_SEC) - 1, + estimated_duration=_VERIFICATION_POLL_TIMEOUT_SEC - 1, + extra_text=h5_text, + ) + + if not result.group_id: + raise RuntimeError(f"Seedance session {session.session_id} completed without a group_id") + + logger.warning("Seedance authentication complete. New GroupId: %s", result.group_id) + PromptServer.instance.send_progress_text( + f"Authentication complete. New GroupId: {result.group_id}", cls.hidden.unique_id + ) + return result.group_id + + +async def _resolve_group_id(cls: type[IO.ComfyNode], group_id: str) -> str: + if group_id and group_id.strip(): + return group_id.strip() + return await _obtain_group_id_via_h5_auth(cls) + + +async def _create_seedance_asset( + cls: type[IO.ComfyNode], + *, + group_id: str, + url: str, + name: str, + asset_type: str, +) -> str: + req = SeedanceCreateAssetRequest( + group_id=group_id, + url=url, + asset_type=asset_type, + name=name or None, + ) + result = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=req, + ) + return result.asset_id + + +async def _wait_for_asset_active(cls: type[IO.ComfyNode], asset_id: str, group_id: str) -> GetAssetResponse: + """Poll the newly created asset until its status becomes Active.""" + return await poll_op( + cls, + ApiEndpoint(path=f"/proxy/seedance/assets/{asset_id}"), + response_model=GetAssetResponse, + status_extractor=lambda r: r.status, + completed_statuses=["Active"], + failed_statuses=["Failed"], + poll_interval=5, + max_poll_attempts=1200, + extra_text=f"Waiting for asset pre-processing...\n\nasset_id: {asset_id}\n\ngroup_id: {group_id}", + ) + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -1224,12 +1403,27 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): IO.Image.Input( "first_frame", tooltip="First frame image for the video.", + optional=True, ), IO.Image.Input( "last_frame", tooltip="Last frame image for the video.", optional=True, ), + IO.String.Input( + "first_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the first frame. " + "Mutually exclusive with the first_frame image input.", + optional=True, + ), + IO.String.Input( + "last_frame_asset_id", + default="", + tooltip="Seedance asset_id to use as the last frame. " + "Mutually exclusive with the last_frame image input.", + optional=True, + ), IO.Int.Input( "seed", default=0, @@ -1282,24 +1476,54 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): async def execute( cls, model: dict, - first_frame: Input.Image, seed: int, watermark: bool, + first_frame: Input.Image | None = None, last_frame: Input.Image | None = None, + first_frame_asset_id: str = "", + last_frame_asset_id: str = "", ) -> IO.NodeOutput: validate_string(model["prompt"], strip_whitespace=True, min_length=1) model_id = SEEDANCE_MODELS[model["model"]] + first_frame_asset_id = first_frame_asset_id.strip() + last_frame_asset_id = last_frame_asset_id.strip() + + if first_frame is not None and first_frame_asset_id: + raise ValueError("Provide only one of first_frame or first_frame_asset_id, not both.") + if first_frame is None and not first_frame_asset_id: + raise ValueError("Either first_frame or first_frame_asset_id is required.") + if last_frame is not None and last_frame_asset_id: + raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") + + asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] + image_assets: dict[str, str] = {} + if asset_ids_to_resolve: + image_assets, _, _ = await _resolve_reference_assets(cls, asset_ids_to_resolve) + for aid in asset_ids_to_resolve: + if aid not in image_assets: + raise ValueError(f"Asset {aid} is not an Image asset.") + + if first_frame_asset_id: + first_frame_url = image_assets[first_frame_asset_id] + else: + first_frame_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") + content: list[TaskTextContent | TaskImageContent] = [ TaskTextContent(text=model["prompt"]), TaskImageContent( - image_url=TaskImageContentUrl( - url=await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame.") - ), + image_url=TaskImageContentUrl(url=first_frame_url), role="first_frame", ), ] - if last_frame is not None: + if last_frame_asset_id: + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=image_assets[last_frame_asset_id]), + role="last_frame", + ), + ) + elif last_frame is not None: content.append( TaskImageContent( image_url=TaskImageContentUrl( @@ -1373,6 +1597,32 @@ def _seedance2_reference_inputs(resolutions: list[str]): min=0, ), ), + IO.Boolean.Input( + "auto_downscale", + default=False, + advanced=True, + optional=True, + tooltip="Automatically downscale reference videos that exceed the model's pixel budget " + "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", + ), + IO.Autogrow.Input( + "reference_assets", + template=IO.Autogrow.TemplateNames( + IO.String.Input("reference_asset"), + names=[ + "asset_1", + "asset_2", + "asset_3", + "asset_4", + "asset_5", + "asset_6", + "asset_7", + "asset_8", + "asset_9", + ], + min=0, + ), + ), ] @@ -1474,16 +1724,47 @@ class ByteDance2ReferenceNode(IO.ComfyNode): reference_images = model.get("reference_images", {}) reference_videos = model.get("reference_videos", {}) reference_audios = model.get("reference_audios", {}) + reference_assets = model.get("reference_assets", {}) - if not reference_images and not reference_videos: - raise ValueError("At least one reference image or video is required.") + reference_image_assets, reference_video_assets, reference_audio_assets = await _resolve_reference_assets( + cls, list(reference_assets.values()) + ) + + if not reference_images and not reference_videos and not reference_image_assets and not reference_video_assets: + raise ValueError("At least one reference image or video or asset is required.") + + total_images = len(reference_images) + len(reference_image_assets) + if total_images > 9: + raise ValueError( + f"Too many reference images: {total_images} " + f"(images={len(reference_images)}, image assets={len(reference_image_assets)}). Maximum is 9." + ) + total_videos = len(reference_videos) + len(reference_video_assets) + if total_videos > 3: + raise ValueError( + f"Too many reference videos: {total_videos} " + f"(videos={len(reference_videos)}, video assets={len(reference_video_assets)}). Maximum is 3." + ) + total_audios = len(reference_audios) + len(reference_audio_assets) + if total_audios > 3: + raise ValueError( + f"Too many reference audios: {total_audios} " + f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3." + ) model_id = SEEDANCE_MODELS[model["model"]] - has_video_input = len(reference_videos) > 0 + has_video_input = total_videos > 0 + + if model.get("auto_downscale") and reference_videos: + max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max") + if max_px: + for key in reference_videos: + reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px) + total_video_duration = 0.0 for i, key in enumerate(reference_videos, 1): video = reference_videos[key] - _validate_ref_video_pixels(video, model_id, i) + _validate_ref_video_pixels(video, model_id, model["resolution"], i) try: dur = video.get_duration() if dur < 1.8: @@ -1506,8 +1787,19 @@ class ByteDance2ReferenceNode(IO.ComfyNode): if total_audio_duration > 15.1: raise ValueError(f"Total reference audio duration is {total_audio_duration:.1f}s. Maximum is 15.1 seconds.") + asset_labels = _build_asset_labels( + reference_assets, + reference_image_assets, + reference_video_assets, + reference_audio_assets, + len(reference_images), + len(reference_videos), + len(reference_audios), + ) + prompt_text = _rewrite_asset_refs(model["prompt"], asset_labels) + content: list[TaskTextContent | TaskImageContent | TaskVideoContent | TaskAudioContent] = [ - TaskTextContent(text=model["prompt"]), + TaskTextContent(text=prompt_text), ] for i, key in enumerate(reference_images, 1): content.append( @@ -1548,6 +1840,21 @@ class ByteDance2ReferenceNode(IO.ComfyNode): ), ), ) + for url in reference_image_assets.values(): + content.append( + TaskImageContent( + image_url=TaskImageContentUrl(url=url), + role="reference_image", + ), + ) + for url in reference_video_assets.values(): + content.append( + TaskVideoContent(video_url=TaskVideoContentUrl(url=url)), + ) + for url in reference_audio_assets.values(): + content.append( + TaskAudioContent(audio_url=TaskAudioContentUrl(url=url)), + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), @@ -1602,6 +1909,156 @@ async def process_video_task( return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) +class ByteDanceCreateImageAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateImageAsset", + display_name="ByteDance Create Image Asset", + category="api node/image/ByteDance", + description=( + "Create a Seedance 2.0 personal image asset. Uploads the input image and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Image.Input("image", tooltip="Image to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_image_dimensions(image, min_width=300, max_width=6000, min_height=300, max_height=6000) + validate_image_aspect_ratio(image, min_ratio=(0.4, 1), max_ratio=(2.5, 1)) + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_image_to_comfyapi(cls, image), + name="", + asset_type="Image", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + +class ByteDanceCreateVideoAsset(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="ByteDanceCreateVideoAsset", + display_name="ByteDance Create Video Asset", + category="api node/video/ByteDance", + description=( + "Create a Seedance 2.0 personal video asset. Uploads the input video and " + "registers it in the given asset group. If group_id is empty, runs a real-person " + "H5 authentication flow to create a new group before adding the asset." + ), + inputs=[ + IO.Video.Input("video", tooltip="Video to register as a personal asset."), + IO.String.Input( + "group_id", + default="", + tooltip="Reuse an existing Seedance asset group ID to skip repeated human verification for the " + "same person. Leave empty to run real-person authentication in the browser and create a new group.", + ), + # IO.String.Input( + # "name", + # default="", + # tooltip="Asset name (up to 64 characters).", + # ), + ], + outputs=[ + IO.String.Output(display_name="asset_id"), + IO.String.Output(display_name="group_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + # is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + group_id: str = "", + # name: str = "", + ) -> IO.NodeOutput: + # if len(name) > 64: + # raise ValueError("Name of asset can not be greater then 64 symbols") + validate_video_duration(video, min_duration=2, max_duration=15) + validate_video_dimensions(video, min_width=300, max_width=6000, min_height=300, max_height=6000) + + w, h = video.get_dimensions() + if h > 0: + ratio = w / h + if not (0.4 <= ratio <= 2.5): + raise ValueError(f"Asset video aspect ratio (W/H) must be in [0.4, 2.5], got {ratio:.3f} ({w}x{h}).") + pixels = w * h + if not (409_600 <= pixels <= 927_408): + raise ValueError( + f"Asset video total pixels (W×H) must be in [409600, 927408], " f"got {pixels:,} ({w}x{h})." + ) + + fps = float(video.get_frame_rate()) + if not (24 <= fps <= 60): + raise ValueError(f"Asset video FPS must be in [24, 60], got {fps:.2f}.") + + resolved_group = await _resolve_group_id(cls, group_id) + asset_id = await _create_seedance_asset( + cls, + group_id=resolved_group, + url=await upload_video_to_comfyapi(cls, video), + name="", + asset_type="Video", + ) + await _wait_for_asset_active(cls, asset_id, resolved_group) + PromptServer.instance.send_progress_text( + f"Please save the asset_id and group_id for reuse.\n\nasset_id: {asset_id}\n\n" + f"group_id: {resolved_group}", + cls.hidden.unique_id, + ) + return IO.NodeOutput(asset_id, resolved_group) + + class ByteDanceExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1615,6 +2072,8 @@ class ByteDanceExtension(ComfyExtension): ByteDance2TextToVideoNode, ByteDance2FirstLastFrameNode, ByteDance2ReferenceNode, + ByteDanceCreateImageAsset, + ByteDanceCreateVideoAsset, ] diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 9a37ccc53..709b3726c 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -276,6 +276,7 @@ async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusRe cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -862,7 +863,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): ), IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -904,12 +905,13 @@ class OmniProTextToVideoNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -934,6 +936,8 @@ class OmniProTextToVideoNode(IO.ComfyNode): raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -963,6 +967,12 @@ class OmniProTextToVideoNode(IO.ComfyNode): f"must equal the global duration ({duration}s)." ) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -972,7 +982,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): prompt=prompt, aspect_ratio=aspect_ratio, duration=str(duration), - mode="pro" if resolution == "1080p" else "std", + mode=mode, multi_shot=multi_shot, multi_prompt=multi_prompt_list, shot_type="customize" if multi_shot else None, @@ -1014,7 +1024,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): optional=True, tooltip="Up to 6 additional reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1061,12 +1071,13 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1093,6 +1104,8 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1161,6 +1174,12 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1170,7 +1189,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): prompt=prompt, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -1204,7 +1223,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): "reference_images", tooltip="Up to 7 reference images.", ), - IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p", optional=True), IO.DynamicCombo.Input( "storyboards", options=[ @@ -1251,12 +1270,13 @@ class OmniProImageToVideoNode(IO.ComfyNode): depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]), expr=""" ( - $mode := (widgets.resolution = "720p") ? "std" : "pro"; + $res := widgets.resolution; + $mode := $res = "4k" ? "4k" : ($res = "720p" ? "std" : "pro"); $isV3 := $contains(widgets.model_name, "v3"); $audio := $isV3 and widgets.generate_audio; $rates := $audio - ? {"std": 0.112, "pro": 0.14} - : {"std": 0.084, "pro": 0.112}; + ? {"std": 0.112, "pro": 0.14, "4k": 0.42} + : {"std": 0.084, "pro": 0.112, "4k": 0.42}; {"type":"usd","usd": $lookup($rates, $mode) * widgets.duration} ) """, @@ -1282,6 +1302,8 @@ class OmniProImageToVideoNode(IO.ComfyNode): raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.") if generate_audio: raise ValueError("kling-video-o1 does not support audio generation.") + if resolution == "4k": + raise ValueError("kling-video-o1 does not support 4k resolution.") stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled" if stories_enabled and model_name == "kling-video-o1": raise ValueError("kling-video-o1 does not support storyboards.") @@ -1320,6 +1342,12 @@ class OmniProImageToVideoNode(IO.ComfyNode): image_list: list[OmniParamImage] = [] for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): image_list.append(OmniParamImage(image_url=i)) + if resolution == "4k": + mode = "4k" + elif resolution == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), @@ -1330,7 +1358,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): aspect_ratio=aspect_ratio, duration=str(duration), image_list=image_list, - mode="pro" if resolution == "1080p" else "std", + mode=mode, sound="on" if generate_audio else "off", multi_shot=multi_shot, multi_prompt=multi_prompt_list, @@ -2860,7 +2888,7 @@ class KlingVideoNode(IO.ComfyNode): IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16", "1:1"], @@ -2913,7 +2941,11 @@ class KlingVideoNode(IO.ComfyNode): ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -2943,7 +2975,12 @@ class KlingVideoNode(IO.ComfyNode): start_frame: Input.Image | None = None, ) -> IO.NodeOutput: _ = seed - mode = "pro" if model["resolution"] == "1080p" else "std" + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" custom_multi_shot = False if multi_shot["multi_shot"] == "disabled": shot_type = None @@ -3025,6 +3062,7 @@ class KlingVideoNode(IO.ComfyNode): cls, ApiEndpoint(path=poll_path), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) @@ -3057,7 +3095,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): IO.DynamicCombo.Option( "kling-v3", [ - IO.Combo.Input("resolution", options=["1080p", "720p"]), + IO.Combo.Input("resolution", options=["4k", "1080p", "720p"], default="1080p"), ], ), ], @@ -3089,7 +3127,11 @@ class KlingFirstLastFrameNode(IO.ComfyNode): ), expr=""" ( - $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}}; + $rates := { + "4k": {"off": 0.42, "on": 0.42}, + "1080p": {"off": 0.112, "on": 0.168}, + "720p": {"off": 0.084, "on": 0.126} + }; $res := $lookup(widgets, "model.resolution"); $audio := widgets.generate_audio ? "on" : "off"; $rate := $lookup($lookup($rates, $res), $audio); @@ -3118,6 +3160,12 @@ class KlingFirstLastFrameNode(IO.ComfyNode): validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame") image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame") + if model["resolution"] == "4k": + mode = "4k" + elif model["resolution"] == "1080p": + mode = "pro" + else: + mode = "std" response = await sync_op( cls, ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"), @@ -3127,7 +3175,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): image=image_url, image_tail=image_tail_url, prompt=prompt, - mode="pro" if model["resolution"] == "1080p" else "std", + mode=mode, duration=str(duration), sound="on" if generate_audio else "off", ), @@ -3140,6 +3188,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): cls, ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"), response_model=TaskStatusResponse, + max_poll_attempts=280, status_extractor=lambda r: (r.data.task_status if r.data else None), ) return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index 4ee896fa8..bbb758068 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -357,13 +357,17 @@ def calculate_tokens_price_image_1_5(response: OpenAIImageGenerationResponse) -> return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 32.0)) / 1_000_000.0 +def calculate_tokens_price_image_2_0(response: OpenAIImageGenerationResponse) -> float | None: + return ((response.usage.input_tokens * 8.0) + (response.usage.output_tokens * 30.0)) / 1_000_000.0 + + class OpenAIGPTImage1(IO.ComfyNode): @classmethod def define_schema(cls): return IO.Schema( node_id="OpenAIGPTImage1", - display_name="OpenAI GPT Image 1.5", + display_name="OpenAI GPT Image 2", category="api node/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", inputs=[ @@ -401,7 +405,17 @@ class OpenAIGPTImage1(IO.ComfyNode): IO.Combo.Input( "size", default="auto", - options=["auto", "1024x1024", "1024x1536", "1536x1024"], + options=[ + "auto", + "1024x1024", + "1024x1536", + "1536x1024", + "2048x2048", + "2048x1152", + "1152x2048", + "3840x2160", + "2160x3840", + ], tooltip="Image size", optional=True, ), @@ -427,8 +441,8 @@ class OpenAIGPTImage1(IO.ComfyNode): ), IO.Combo.Input( "model", - options=["gpt-image-1", "gpt-image-1.5"], - default="gpt-image-1.5", + options=["gpt-image-1", "gpt-image-1.5", "gpt-image-2"], + default="gpt-image-2", optional=True, ), ], @@ -442,23 +456,36 @@ class OpenAIGPTImage1(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["quality", "n"]), + depends_on=IO.PriceBadgeDepends(widgets=["quality", "n", "model"]), expr=""" ( $ranges := { - "low": [0.011, 0.02], - "medium": [0.046, 0.07], - "high": [0.167, 0.3] + "gpt-image-1": { + "low": [0.011, 0.02], + "medium": [0.042, 0.07], + "high": [0.167, 0.25] + }, + "gpt-image-1.5": { + "low": [0.009, 0.02], + "medium": [0.034, 0.062], + "high": [0.133, 0.22] + }, + "gpt-image-2": { + "low": [0.0048, 0.012], + "medium": [0.041, 0.112], + "high": [0.165, 0.43] + } }; - $range := $lookup($ranges, widgets.quality); - $n := widgets.n; + $range := $lookup($lookup($ranges, widgets.model), widgets.quality); + $nRaw := widgets.n; + $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1; ($n = 1) - ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1]} + ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}} : { "type":"range_usd", - "min_usd": $range[0], - "max_usd": $range[1], - "format": { "suffix": " x " & $string($n) & "/Run" } + "min_usd": $range[0] * $n, + "max_usd": $range[1] * $n, + "format": { "suffix": "/Run", "approximate": true } } ) """, @@ -483,10 +510,18 @@ class OpenAIGPTImage1(IO.ComfyNode): if mask is not None and image is None: raise ValueError("Cannot use a mask without an input image") + if model in ("gpt-image-1", "gpt-image-1.5"): + if size not in ("auto", "1024x1024", "1024x1536", "1536x1024"): + raise ValueError(f"Resolution {size} is only supported by GPT Image 2 model") + if model == "gpt-image-1": price_extractor = calculate_tokens_price_image_1 elif model == "gpt-image-1.5": price_extractor = calculate_tokens_price_image_1_5 + elif model == "gpt-image-2": + price_extractor = calculate_tokens_price_image_2_0 + if background == "transparent": + raise ValueError("Transparent background is not supported for GPT Image 2 model") else: raise ValueError(f"Unknown model: {model}") diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index afc18bb25..4d9075dcf 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -33,9 +33,13 @@ class OpenAIVideoSora2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="OpenAIVideoSora2", - display_name="OpenAI Sora - Video", + display_name="OpenAI Sora - Video (Deprecated)", category="api node/video/Sora", - description="OpenAI video and audio generation.", + description=( + "OpenAI video and audio generation.\n\n" + "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. " + "This node will be removed from ComfyUI at that time." + ), inputs=[ IO.Combo.Input( "model", diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 13fc1cc36..2ff75d9b2 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -24,8 +24,9 @@ from comfy_api_nodes.util import ( AVERAGE_DURATION_VIDEO_GEN = 32 MODELS_MAP = { "veo-2.0-generate-001": "veo-2.0-generate-001", - "veo-3.1-generate": "veo-3.1-generate-preview", - "veo-3.1-fast-generate": "veo-3.1-fast-generate-preview", + "veo-3.1-generate": "veo-3.1-generate-001", + "veo-3.1-fast-generate": "veo-3.1-fast-generate-001", + "veo-3.1-lite": "veo-3.1-lite-generate-001", "veo-3.0-generate-001": "veo-3.0-generate-001", "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } @@ -247,17 +248,8 @@ class VeoVideoGenerationNode(IO.ComfyNode): raise Exception("Video generation completed but no video was returned") -class Veo3VideoGenerationNode(VeoVideoGenerationNode): - """ - Generates videos from text prompts using Google's Veo 3 API. - - Supported models: - - veo-3.0-generate-001 - - veo-3.0-fast-generate-001 - - This node extends the base Veo node with Veo 3 specific features including - audio generation and fixed 8-second duration. - """ +class Veo3VideoGenerationNode(IO.ComfyNode): + """Generates videos from text prompts using Google's Veo 3 API.""" @classmethod def define_schema(cls): @@ -279,6 +271,13 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): default="16:9", tooltip="Aspect ratio of the output video", ), + IO.Combo.Input( + "resolution", + options=["720p", "1080p", "4k"], + default="720p", + tooltip="Output video resolution. 4K is not available for veo-3.1-lite and veo-3.0 models.", + optional=True, + ), IO.String.Input( "negative_prompt", multiline=True, @@ -289,11 +288,11 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Int.Input( "duration_seconds", default=8, - min=8, + min=4, max=8, - step=1, + step=2, display_mode=IO.NumberDisplay.number, - tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)", + tooltip="Duration of the output video in seconds", optional=True, ), IO.Boolean.Input( @@ -332,10 +331,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): options=[ "veo-3.1-generate", "veo-3.1-fast-generate", + "veo-3.1-lite", "veo-3.0-generate-001", "veo-3.0-fast-generate-001", ], - default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", optional=True, ), @@ -356,21 +355,111 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "resolution", "duration_seconds"]), expr=""" ( $m := widgets.model; + $r := widgets.resolution; $a := widgets.generate_audio; - ($contains($m,"veo-3.0-fast-generate-001") or $contains($m,"veo-3.1-fast-generate")) - ? {"type":"usd","usd": ($a ? 1.2 : 0.8)} - : ($contains($m,"veo-3.0-generate-001") or $contains($m,"veo-3.1-generate")) - ? {"type":"usd","usd": ($a ? 3.2 : 1.6)} - : {"type":"range_usd","min_usd":0.8,"max_usd":3.2} + $seconds := widgets.duration_seconds; + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($a ? 0.08 : 0.05) : ($a ? 0.05 : 0.03)) + : $contains($m, "3.1-fast") + ? ($r = "4k" ? ($a ? 0.30 : 0.25) : $r = "1080p" ? ($a ? 0.12 : 0.10) : ($a ? 0.10 : 0.08)) + : $contains($m, "3.1-generate") + ? ($r = "4k" ? ($a ? 0.60 : 0.40) : ($a ? 0.40 : 0.20)) + : $contains($m, "3.0-fast") + ? ($a ? 0.15 : 0.10) + : ($a ? 0.40 : 0.20); + {"type":"usd","usd": $pps * $seconds} ) """, ), ) + @classmethod + async def execute( + cls, + prompt, + aspect_ratio="16:9", + resolution="720p", + negative_prompt="", + duration_seconds=8, + enhance_prompt=True, + person_generation="ALLOW", + seed=0, + image=None, + model="veo-3.0-generate-001", + generate_audio=False, + ): + if resolution == "4k" and ("lite" in model or "3.0" in model): + raise Exception("4K resolution is not supported by the veo-3.1-lite or veo-3.0 models.") + + model = MODELS_MAP[model] + + instances = [{"prompt": prompt}] + if image is not None: + image_base64 = tensor_to_base64_string(image) + if image_base64: + instances[0]["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} + + parameters = { + "aspectRatio": aspect_ratio, + "personGeneration": person_generation, + "durationSeconds": duration_seconds, + "enhancePrompt": True, + "generateAudio": generate_audio, + } + if negative_prompt: + parameters["negativePrompt"] = negative_prompt + if seed > 0: + parameters["seed"] = seed + if "veo-3.1" in model: + parameters["resolution"] = resolution + + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=instances, + parameters=parameters, + ), + ) + + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest(operationName=initial_response.name), + poll_interval=9.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + class Veo3FirstLastFrameNode(IO.ComfyNode): @@ -394,7 +483,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): default="", tooltip="Negative text prompt to guide what to avoid in the video", ), - IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input("resolution", options=["720p", "1080p", "4k"]), IO.Combo.Input( "aspect_ratio", options=["16:9", "9:16"], @@ -424,8 +513,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): IO.Image.Input("last_frame", tooltip="End frame"), IO.Combo.Input( "model", - options=["veo-3.1-generate", "veo-3.1-fast-generate"], - default="veo-3.1-fast-generate", + options=["veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.1-lite"], ), IO.Boolean.Input( "generate_audio", @@ -443,26 +531,20 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "generate_audio", "duration", "resolution"]), expr=""" ( - $prices := { - "veo-3.1-fast-generate": { "audio": 0.15, "no_audio": 0.10 }, - "veo-3.1-generate": { "audio": 0.40, "no_audio": 0.20 } - }; $m := widgets.model; - $ga := (widgets.generate_audio = "true"); + $r := widgets.resolution; + $ga := widgets.generate_audio; $seconds := widgets.duration; - $modelKey := - $contains($m, "veo-3.1-fast-generate") ? "veo-3.1-fast-generate" : - $contains($m, "veo-3.1-generate") ? "veo-3.1-generate" : - ""; - $audioKey := $ga ? "audio" : "no_audio"; - $modelPrices := $lookup($prices, $modelKey); - $pps := $lookup($modelPrices, $audioKey); - ($pps != null) - ? {"type":"usd","usd": $pps * $seconds} - : {"type":"range_usd","min_usd": 0.4, "max_usd": 3.2} + $pps := + $contains($m, "lite") + ? ($r = "1080p" ? ($ga ? 0.08 : 0.05) : ($ga ? 0.05 : 0.03)) + : $contains($m, "fast") + ? ($r = "4k" ? ($ga ? 0.30 : 0.25) : $r = "1080p" ? ($ga ? 0.12 : 0.10) : ($ga ? 0.10 : 0.08)) + : ($r = "4k" ? ($ga ? 0.60 : 0.40) : ($ga ? 0.40 : 0.20)); + {"type":"usd","usd": $pps * $seconds} ) """, ), @@ -482,6 +564,9 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): model: str, generate_audio: bool, ): + if "lite" in model and resolution == "4k": + raise Exception("4K resolution is not supported by the veo-3.1-lite model.") + model = MODELS_MAP[model] initial_response = await sync_op( cls, @@ -519,7 +604,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): data=VeoGenVidPollRequest( operationName=initial_response.name, ), - poll_interval=5.0, + poll_interval=9.0, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index 0cb9a47c7..f3584aba9 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -19,6 +19,7 @@ from .conversions import ( image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, + resize_video_to_pixel_budget, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, @@ -90,6 +91,7 @@ __all__ = [ "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", + "resize_video_to_pixel_budget", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 9d730b81a..b0cf97ae4 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -156,6 +156,7 @@ async def poll_op( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> M: raw = await poll_op_raw( cls, @@ -176,6 +177,7 @@ async def poll_op( estimated_duration=estimated_duration, cancel_endpoint=cancel_endpoint, cancel_timeout=cancel_timeout, + extra_text=extra_text, ) if not isinstance(raw, dict): raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") @@ -260,6 +262,7 @@ async def poll_op_raw( estimated_duration: int | None = None, cancel_endpoint: ApiEndpoint | None = None, cancel_timeout: float = 10.0, + extra_text: str | None = None, ) -> dict[str, Any]: """ Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, @@ -299,6 +302,7 @@ async def poll_op_raw( price=state.price, is_queued=state.is_queued, processing_elapsed_seconds=int(proc_elapsed), + extra_text=extra_text, ) await asyncio.sleep(1.0) except Exception as exc: @@ -389,6 +393,7 @@ async def poll_op_raw( price=state.price, is_queued=False, processing_elapsed_seconds=int(state.base_processing_elapsed), + extra_text=extra_text, ) return resp_json @@ -462,6 +467,7 @@ def _display_time_progress( price: float | None = None, is_queued: bool | None = None, processing_elapsed_seconds: int | None = None, + extra_text: str | None = None, ) -> None: if estimated_total is not None and estimated_total > 0 and is_queued is False: pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds @@ -469,7 +475,8 @@ def _display_time_progress( time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" else: time_line = f"Time elapsed: {int(elapsed_seconds)}s" - _display_text(node_cls, time_line, status=status, price=price) + text = f"{time_line}\n\n{extra_text}" if extra_text else time_line + _display_text(node_cls, text, status=status, price=price) async def _diagnose_connectivity() -> dict[str, bool]: diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 82b6d22a5..be5d5719b 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -129,22 +129,38 @@ def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: return img_byte_arr +def _compute_downscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None: + """Return downscaled (w, h) with even dims fitting ``total_pixels``, or None if already fits. + + Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions + are rounded down to even values (many codecs require divisible-by-2). + """ + pixels = src_w * src_h + if pixels <= total_pixels: + return None + scale = math.sqrt(total_pixels / pixels) + new_w = max(2, int(src_w * scale)) + new_h = max(2, int(src_h * scale)) + new_w -= new_w % 2 + new_h -= new_h % 2 + return new_w, new_h + + def downscale_image_tensor(image: torch.Tensor, total_pixels: int = 1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" + """Downscale input image tensor to roughly the specified total pixels. + + Output dimensions are rounded down to even values so that the result is guaranteed to fit within ``total_pixels`` + and is compatible with codecs that require even dimensions (e.g. yuv420p). + """ samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: + dims = _compute_downscale_dims(samples.shape[3], samples.shape[2], int(total_pixels)) + if dims is None: return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s + new_w, new_h = dims + return common_upscale(samples, new_w, new_h, "lanczos", "disabled").movedim(1, -1) -def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: +def downscale_image_tensor_by_max_side(image: torch.Tensor, *, max_side: int) -> torch.Tensor: """Downscale input image tensor so the largest dimension is at most max_side pixels.""" samples = image.movedim(-1, 1) height, width = samples.shape[2], samples.shape[3] @@ -399,6 +415,72 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: raise RuntimeError(f"Failed to trim video: {str(e)}") from e +def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video: + """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio. + + Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio. + Aspect ratio is preserved up to a fraction of a percent (even-dim rounding). + """ + src_w, src_h = video.get_dimensions() + scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels) + if scale_dims is None: + return video + return _apply_video_scale(video, scale_dims) + + +def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input.Video: + """Re-encode ``video`` scaled to ``scale_dims`` with a single decode/encode pass.""" + out_w, out_h = scale_dims + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + input_source = video.get_stream_source() + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + video_stream = output_container.add_stream("h264", rate=video.get_frame_rate()) + video_stream.width = out_w + video_stream.height = out_h + video_stream.pix_fmt = "yuv420p" + + audio_stream = None + for stream in input_container.streams: + if isinstance(stream, av.AudioStream): + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + break + + for frame in input_container.decode(video=0): + frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + for packet in video_stream.encode(frame): + output_container.mux(packet) + for packet in video_stream.encode(): + output_container.mux(packet) + + if audio_stream is not None: + input_container.seek(0) + for audio_frame in input_container.decode(audio=0): + for packet in audio_stream.encode(audio_frame): + output_container.mux(packet) + for packet in audio_stream.encode(): + output_container.mux(packet) + + output_container.close() + input_container.close() + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to resize video: {str(e)}") from e + + def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" if wav.dtype.is_floating_point: diff --git a/comfy_extras/frame_interpolation_models/film_net.py b/comfy_extras/frame_interpolation_models/film_net.py new file mode 100644 index 000000000..cf4f6e1e1 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/film_net.py @@ -0,0 +1,258 @@ +"""FILM: Frame Interpolation for Large Motion (ECCV 2022).""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +class FilmConv2d(nn.Module): + """Conv2d with optional LeakyReLU and FILM-style padding.""" + + def __init__(self, in_channels, out_channels, size, activation=True, device=None, dtype=None, operations=ops): + super().__init__() + self.even_pad = not size % 2 + self.conv = operations.Conv2d(in_channels, out_channels, kernel_size=size, padding=size // 2 if size % 2 else 0, device=device, dtype=dtype) + self.activation = nn.LeakyReLU(0.2) if activation else None + + def forward(self, x): + if self.even_pad: + x = F.pad(x, (0, 1, 0, 1)) + x = self.conv(x) + if self.activation is not None: + x = self.activation(x) + return x + + +def _warp_core(image, flow, grid_x, grid_y): + dtype = image.dtype + H, W = flow.shape[2], flow.shape[3] + dx = flow[:, 0].float() / (W * 0.5) + dy = flow[:, 1].float() / (H * 0.5) + grid = torch.stack([grid_x[None, None, :] + dx, grid_y[None, :, None] + dy], dim=3) + return F.grid_sample(image.float(), grid, mode="bilinear", padding_mode="border", align_corners=False).to(dtype) + + +def build_image_pyramid(image, pyramid_levels): + pyramid = [image] + for _ in range(1, pyramid_levels): + image = F.avg_pool2d(image, 2, 2) + pyramid.append(image) + return pyramid + + +def flow_pyramid_synthesis(residual_pyramid): + flow = residual_pyramid[-1] + flow_pyramid = [flow] + for residual_flow in residual_pyramid[:-1][::-1]: + flow = F.interpolate(flow, size=residual_flow.shape[2:4], mode="bilinear", scale_factor=None).mul_(2).add_(residual_flow) + flow_pyramid.append(flow) + flow_pyramid.reverse() + return flow_pyramid + + +def multiply_pyramid(pyramid, scalar): + return [image * scalar[:, None, None, None] for image in pyramid] + + +def pyramid_warp(feature_pyramid, flow_pyramid, warp_fn): + return [warp_fn(features, flow) for features, flow in zip(feature_pyramid, flow_pyramid)] + + +def concatenate_pyramids(pyramid1, pyramid2): + return [torch.cat([f1, f2], dim=1) for f1, f2 in zip(pyramid1, pyramid2)] + + +class SubTreeExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, n_layers=4, device=None, dtype=None, operations=ops): + super().__init__() + convs = [] + for i in range(n_layers): + out_ch = channels << i + convs.append(nn.Sequential( + FilmConv2d(in_channels, out_ch, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(out_ch, out_ch, 3, device=device, dtype=dtype, operations=operations))) + in_channels = out_ch + self.convs = nn.ModuleList(convs) + + def forward(self, image, n): + head = image + pyramid = [] + for i, layer in enumerate(self.convs): + head = layer(head) + pyramid.append(head) + if i < n - 1: + head = F.avg_pool2d(head, 2, 2) + return pyramid + + +class FeatureExtractor(nn.Module): + def __init__(self, in_channels=3, channels=64, sub_levels=4, device=None, dtype=None, operations=ops): + super().__init__() + self.extract_sublevels = SubTreeExtractor(in_channels, channels, sub_levels, device=device, dtype=dtype, operations=operations) + self.sub_levels = sub_levels + + def forward(self, image_pyramid): + sub_pyramids = [self.extract_sublevels(image_pyramid[i], min(len(image_pyramid) - i, self.sub_levels)) + for i in range(len(image_pyramid))] + feature_pyramid = [] + for i in range(len(image_pyramid)): + features = sub_pyramids[i][0] + for j in range(1, self.sub_levels): + if j <= i: + features = torch.cat([features, sub_pyramids[i - j][j]], dim=1) + feature_pyramid.append(features) + # Free sub-pyramids no longer needed by future levels + if i >= self.sub_levels - 1: + sub_pyramids[i - self.sub_levels + 1] = None + return feature_pyramid + + +class FlowEstimator(nn.Module): + def __init__(self, in_channels, num_convs, num_filters, device=None, dtype=None, operations=ops): + super().__init__() + self._convs = nn.ModuleList() + for _ in range(num_convs): + self._convs.append(FilmConv2d(in_channels, num_filters, 3, device=device, dtype=dtype, operations=operations)) + in_channels = num_filters + self._convs.append(FilmConv2d(in_channels, num_filters // 2, 1, device=device, dtype=dtype, operations=operations)) + self._convs.append(FilmConv2d(num_filters // 2, 2, 1, activation=False, device=device, dtype=dtype, operations=operations)) + + def forward(self, features_a, features_b): + net = torch.cat([features_a, features_b], dim=1) + for conv in self._convs: + net = conv(net) + return net + + +class PyramidFlowEstimator(nn.Module): + def __init__(self, filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + in_channels = filters << 1 + predictors = [] + for i in range(len(flow_convs)): + predictors.append(FlowEstimator(in_channels, flow_convs[i], flow_filters[i], device=device, dtype=dtype, operations=operations)) + in_channels += filters << (i + 2) + self._predictor = predictors[-1] + self._predictors = nn.ModuleList(predictors[:-1][::-1]) + + def forward(self, feature_pyramid_a, feature_pyramid_b, warp_fn): + levels = len(feature_pyramid_a) + v = self._predictor(feature_pyramid_a[-1], feature_pyramid_b[-1]) + residuals = [v] + # Coarse-to-fine: shared predictor for deep levels, then specialized predictors for fine levels + steps = [(i, self._predictor) for i in range(levels - 2, len(self._predictors) - 1, -1)] + steps += [(len(self._predictors) - 1 - k, p) for k, p in enumerate(self._predictors)] + for i, predictor in steps: + v = F.interpolate(v, size=feature_pyramid_a[i].shape[2:4], mode="bilinear").mul_(2) + v_residual = predictor(feature_pyramid_a[i], warp_fn(feature_pyramid_b[i], v)) + residuals.append(v_residual) + v = v.add_(v_residual) + residuals.reverse() + return residuals + + +def _get_fusion_channels(level, filters): + # Per direction: multi-scale features + RGB image (3ch) + flow (2ch), doubled for both directions + return (sum(filters << i for i in range(level)) + 3 + 2) * 2 + + +class Fusion(nn.Module): + def __init__(self, n_layers=4, specialized_layers=3, filters=64, device=None, dtype=None, operations=ops): + super().__init__() + self.output_conv = operations.Conv2d(filters, 3, kernel_size=1, device=device, dtype=dtype) + self.convs = nn.ModuleList() + in_channels = _get_fusion_channels(n_layers, filters) + increase = 0 + for i in range(n_layers)[::-1]: + num_filters = (filters << i) if i < specialized_layers else (filters << specialized_layers) + self.convs.append(nn.ModuleList([ + FilmConv2d(in_channels, num_filters, 2, activation=False, device=device, dtype=dtype, operations=operations), + FilmConv2d(in_channels + (increase or num_filters), num_filters, 3, device=device, dtype=dtype, operations=operations), + FilmConv2d(num_filters, num_filters, 3, device=device, dtype=dtype, operations=operations)])) + in_channels = num_filters + increase = _get_fusion_channels(i, filters) - num_filters // 2 + + def forward(self, pyramid): + net = pyramid[-1] + for k, layers in enumerate(self.convs): + i = len(self.convs) - 1 - k + net = layers[0](F.interpolate(net, size=pyramid[i].shape[2:4], mode="nearest")) + net = layers[2](layers[1](torch.cat([pyramid[i], net], dim=1))) + return self.output_conv(net) + + +class FILMNet(nn.Module): + def __init__(self, pyramid_levels=7, fusion_pyramid_levels=5, specialized_levels=3, sub_levels=4, + filters=64, flow_convs=(3, 3, 3, 3), flow_filters=(32, 64, 128, 256), device=None, dtype=None, operations=ops): + super().__init__() + self.pyramid_levels = pyramid_levels + self.fusion_pyramid_levels = fusion_pyramid_levels + self.extract = FeatureExtractor(3, filters, sub_levels, device=device, dtype=dtype, operations=operations) + self.predict_flow = PyramidFlowEstimator(filters, flow_convs, flow_filters, device=device, dtype=dtype, operations=operations) + self.fuse = Fusion(sub_levels, specialized_levels, filters, device=device, dtype=dtype, operations=operations) + self._warp_grids = {} + + def get_dtype(self): + return self.extract.extract_sublevels.convs[0][0].conv.weight.dtype + + def _build_warp_grids(self, H, W, device): + """Pre-compute warp grids for all pyramid levels.""" + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + for _ in range(self.pyramid_levels): + self._warp_grids[(H, W)] = ( + torch.linspace(-(1 - 1 / W), 1 - 1 / W, W, dtype=torch.float32, device=device), + torch.linspace(-(1 - 1 / H), 1 - 1 / H, H, dtype=torch.float32, device=device), + ) + H, W = H // 2, W // 2 + + def warp(self, image, flow): + grid_x, grid_y = self._warp_grids[(flow.shape[2], flow.shape[3])] + return _warp_core(image, flow, grid_x, grid_y) + + def extract_features(self, img): + """Extract image and feature pyramids for a single frame. Can be cached across pairs.""" + image_pyramid = build_image_pyramid(img, self.pyramid_levels) + feature_pyramid = self.extract(image_pyramid) + return image_pyramid, feature_pyramid + + def forward(self, img0, img1, timestep=0.5, cache=None): + # FILM uses a scalar timestep per batch element (spatially-varying timesteps not supported) + t = timestep.mean(dim=(1, 2, 3)).item() if isinstance(timestep, torch.Tensor) else timestep + return self.forward_multi_timestep(img0, img1, [t], cache=cache) + + def forward_multi_timestep(self, img0, img1, timesteps, cache=None): + """Compute flow once, synthesize at multiple timesteps. Expects batch=1 inputs.""" + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + image_pyr0, feat_pyr0 = cache["img0"] if cache and "img0" in cache else self.extract_features(img0) + image_pyr1, feat_pyr1 = cache["img1"] if cache and "img1" in cache else self.extract_features(img1) + + fwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr0, feat_pyr1, self.warp))[:self.fusion_pyramid_levels] + bwd_flow = flow_pyramid_synthesis(self.predict_flow(feat_pyr1, feat_pyr0, self.warp))[:self.fusion_pyramid_levels] + + # Build warp targets and free full pyramids (only first fpl levels needed from here) + fpl = self.fusion_pyramid_levels + p2w = [concatenate_pyramids(image_pyr0[:fpl], feat_pyr0[:fpl]), + concatenate_pyramids(image_pyr1[:fpl], feat_pyr1[:fpl])] + del image_pyr0, image_pyr1, feat_pyr0, feat_pyr1 + + results = [] + dt_tensors = torch.tensor(timesteps, device=img0.device, dtype=img0.dtype) + for idx in range(len(timesteps)): + batch_dt = dt_tensors[idx:idx + 1] + bwd_scaled = multiply_pyramid(bwd_flow, batch_dt) + fwd_scaled = multiply_pyramid(fwd_flow, 1 - batch_dt) + fwd_warped = pyramid_warp(p2w[0], bwd_scaled, self.warp) + bwd_warped = pyramid_warp(p2w[1], fwd_scaled, self.warp) + aligned = [torch.cat([fw, bw, bf, ff], dim=1) + for fw, bw, bf, ff in zip(fwd_warped, bwd_warped, bwd_scaled, fwd_scaled)] + del fwd_warped, bwd_warped, bwd_scaled, fwd_scaled + results.append(self.fuse(aligned)) + del aligned + return torch.cat(results, dim=0) diff --git a/comfy_extras/frame_interpolation_models/ifnet.py b/comfy_extras/frame_interpolation_models/ifnet.py new file mode 100644 index 000000000..03cb34c50 --- /dev/null +++ b/comfy_extras/frame_interpolation_models/ifnet.py @@ -0,0 +1,128 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops + +ops = comfy.ops.disable_weight_init + + +def _warp(img, flow, warp_grids): + B, _, H, W = img.shape + base_grid, flow_div = warp_grids[(H, W)] + flow_norm = torch.cat([flow[:, 0:1] / flow_div[0], flow[:, 1:2] / flow_div[1]], 1).float() + grid = (base_grid.expand(B, -1, -1, -1) + flow_norm).permute(0, 2, 3, 1) + return F.grid_sample(img.float(), grid, mode="bilinear", padding_mode="border", align_corners=True).to(img.dtype) + + +class Head(nn.Module): + def __init__(self, out_ch=4, device=None, dtype=None, operations=ops): + super().__init__() + self.cnn0 = operations.Conv2d(3, 16, 3, 2, 1, device=device, dtype=dtype) + self.cnn1 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn2 = operations.Conv2d(16, 16, 3, 1, 1, device=device, dtype=dtype) + self.cnn3 = operations.ConvTranspose2d(16, out_ch, 4, 2, 1, device=device, dtype=dtype) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + x = self.relu(self.cnn0(x)) + x = self.relu(self.cnn1(x)) + x = self.relu(self.cnn2(x)) + return self.cnn3(x) + + +class ResConv(nn.Module): + def __init__(self, c, device=None, dtype=None, operations=ops): + super().__init__() + self.conv = operations.Conv2d(c, c, 3, 1, 1, device=device, dtype=dtype) + self.beta = nn.Parameter(torch.ones((1, c, 1, 1), device=device, dtype=dtype)) + self.relu = nn.LeakyReLU(0.2, True) + + def forward(self, x): + return self.relu(torch.addcmul(x, self.conv(x), self.beta)) + + +class IFBlock(nn.Module): + def __init__(self, in_planes, c=64, device=None, dtype=None, operations=ops): + super().__init__() + self.conv0 = nn.Sequential( + nn.Sequential(operations.Conv2d(in_planes, c // 2, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True)), + nn.Sequential(operations.Conv2d(c // 2, c, 3, 2, 1, device=device, dtype=dtype), nn.LeakyReLU(0.2, True))) + self.convblock = nn.Sequential(*(ResConv(c, device=device, dtype=dtype, operations=operations) for _ in range(8))) + self.lastconv = nn.Sequential(operations.ConvTranspose2d(c, 4 * 13, 4, 2, 1, device=device, dtype=dtype), nn.PixelShuffle(2)) + + def forward(self, x, flow=None, scale=1): + x = F.interpolate(x, scale_factor=1.0 / scale, mode="bilinear") + if flow is not None: + flow = F.interpolate(flow, scale_factor=1.0 / scale, mode="bilinear").div_(scale) + x = torch.cat((x, flow), 1) + feat = self.convblock(self.conv0(x)) + tmp = F.interpolate(self.lastconv(feat), scale_factor=scale, mode="bilinear") + return tmp[:, :4] * scale, tmp[:, 4:5], tmp[:, 5:] + + +class IFNet(nn.Module): + def __init__(self, head_ch=4, channels=(192, 128, 96, 64, 32), device=None, dtype=None, operations=ops): + super().__init__() + self.encode = Head(out_ch=head_ch, device=device, dtype=dtype, operations=operations) + block_in = [7 + 2 * head_ch] + [8 + 4 + 8 + 2 * head_ch] * 4 + self.blocks = nn.ModuleList([IFBlock(block_in[i], channels[i], device=device, dtype=dtype, operations=operations) for i in range(5)]) + self.scale_list = [16, 8, 4, 2, 1] + self.pad_align = 64 + self._warp_grids = {} + + def get_dtype(self): + return self.encode.cnn0.weight.dtype + + def _build_warp_grids(self, H, W, device): + if (H, W) in self._warp_grids: + return + self._warp_grids = {} # clear old resolution grids to prevent memory leaks + grid_y, grid_x = torch.meshgrid( + torch.linspace(-1.0, 1.0, H, device=device, dtype=torch.float32), + torch.linspace(-1.0, 1.0, W, device=device, dtype=torch.float32), indexing="ij") + self._warp_grids[(H, W)] = ( + torch.stack((grid_x, grid_y), dim=0).unsqueeze(0), + torch.tensor([(W - 1.0) / 2.0, (H - 1.0) / 2.0], dtype=torch.float32, device=device)) + + def warp(self, img, flow): + return _warp(img, flow, self._warp_grids) + + def extract_features(self, img): + """Extract head features for a single frame. Can be cached across pairs.""" + return self.encode(img) + + def forward(self, img0, img1, timestep=0.5, cache=None): + if not isinstance(timestep, torch.Tensor): + timestep = torch.full((img0.shape[0], 1, img0.shape[2], img0.shape[3]), timestep, device=img0.device, dtype=img0.dtype) + + self._build_warp_grids(img0.shape[2], img0.shape[3], img0.device) + + B = img0.shape[0] + f0 = cache["img0"].expand(B, -1, -1, -1) if cache and "img0" in cache else self.encode(img0) + f1 = cache["img1"].expand(B, -1, -1, -1) if cache and "img1" in cache else self.encode(img1) + flow = mask = feat = None + warped_img0, warped_img1 = img0, img1 + for i, block in enumerate(self.blocks): + if flow is None: + flow, mask, feat = block(torch.cat((img0, img1, f0, f1, timestep), 1), None, scale=self.scale_list[i]) + else: + fd, mask, feat = block( + torch.cat((warped_img0, warped_img1, self.warp(f0, flow[:, :2]), self.warp(f1, flow[:, 2:4]), timestep, mask, feat), 1), + flow, scale=self.scale_list[i]) + flow = flow.add_(fd) + warped_img0 = self.warp(img0, flow[:, :2]) + warped_img1 = self.warp(img1, flow[:, 2:4]) + return torch.lerp(warped_img1, warped_img0, torch.sigmoid(mask)) + + +def detect_rife_config(state_dict): + head_ch = state_dict["encode.cnn3.weight"].shape[1] # ConvTranspose2d: (in_ch, out_ch, kH, kW) + channels = [] + for i in range(5): + key = f"blocks.{i}.conv0.1.0.weight" + if key in state_dict: + channels.append(state_dict[key].shape[0]) + if len(channels) != 5: + raise ValueError(f"Unsupported RIFE model: expected 5 blocks, found {len(channels)}") + return head_ch, channels diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index a395392d8..5f514716f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -104,7 +104,7 @@ def vae_decode_audio(vae, samples, tile=None, overlap=None): std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - vae_sample_rate = getattr(vae, "audio_sample_rate", 44100) + vae_sample_rate = getattr(vae, "audio_sample_rate_output", getattr(vae, "audio_sample_rate", 44100)) return {"waveform": audio, "sample_rate": vae_sample_rate if "sample_rate" not in samples else samples["sample_rate"]} diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py new file mode 100644 index 000000000..a3b00d36e --- /dev/null +++ b/comfy_extras/nodes_frame_interpolation.py @@ -0,0 +1,211 @@ +import torch +from tqdm import tqdm +from typing_extensions import override + +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy import model_management +from comfy_extras.frame_interpolation_models.ifnet import IFNet, detect_rife_config +from comfy_extras.frame_interpolation_models.film_net import FILMNet +from comfy_api.latest import ComfyExtension, io + +FrameInterpolationModel = io.Custom("INTERP_MODEL") + + +class FrameInterpolationModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolationModelLoader", + display_name="Load Frame Interpolation Model", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), + tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), + ], + outputs=[ + FrameInterpolationModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("frame_interpolation", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + model = cls._detect_and_load(sd) + dtype = torch.float16 if model_management.should_use_fp16(model_management.get_torch_device()) else torch.float32 + model.eval().to(dtype) + patcher = comfy.model_patcher.ModelPatcher( + model, + load_device=model_management.get_torch_device(), + offload_device=model_management.unet_offload_device(), + ) + return io.NodeOutput(patcher) + + @classmethod + def _detect_and_load(cls, sd): + # Try FILM + if "extract.extract_sublevels.convs.0.0.conv.weight" in sd: + model = FILMNet() + model.load_state_dict(sd) + return model + + # Try RIFE (needs key remapping for raw checkpoints) + sd = comfy.utils.state_dict_prefix_replace(sd, {"module.": "", "flownet.": ""}) + key_map = {} + for k in sd: + for i in range(5): + if k.startswith(f"block{i}."): + key_map[k] = f"blocks.{i}.{k[len(f'block{i}.'):]}" + if key_map: + sd = {key_map.get(k, k): v for k, v in sd.items()} + sd = {k: v for k, v in sd.items() if not k.startswith(("teacher.", "caltime."))} + + try: + head_ch, channels = detect_rife_config(sd) + except (KeyError, ValueError): + raise ValueError("Unrecognized frame interpolation model format") + model = IFNet(head_ch=head_ch, channels=channels) + model.load_state_dict(sd) + return model + + +class FrameInterpolate(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="FrameInterpolate", + display_name="Frame Interpolate", + category="image/video", + search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"], + inputs=[ + FrameInterpolationModel.Input("interp_model"), + io.Image.Input("images"), + io.Int.Input("multiplier", default=2, min=2, max=16), + ], + outputs=[ + io.Image.Output(), + ], + ) + + @classmethod + def execute(cls, interp_model, images, multiplier) -> io.NodeOutput: + offload_device = model_management.intermediate_device() + + num_frames = images.shape[0] + if num_frames < 2 or multiplier < 2: + return io.NodeOutput(images) + + model_management.load_model_gpu(interp_model) + device = interp_model.load_device + dtype = interp_model.model_dtype() + inference_model = interp_model.model + + # Free VRAM for inference activations (model weights + ~20x a single frame's worth) + H, W = images.shape[1], images.shape[2] + activation_mem = H * W * 3 * images.element_size() * 20 + model_management.free_memory(activation_mem, device) + align = getattr(inference_model, "pad_align", 1) + + # Prepare a single padded frame on device for determining output dimensions + def prepare_frame(idx): + frame = images[idx:idx + 1].movedim(-1, 1).to(dtype=dtype, device=device) + if align > 1: + from comfy.ldm.common_dit import pad_to_patch_size + frame = pad_to_patch_size(frame, (align, align), padding_mode="reflect") + return frame + + # Count total interpolation passes for progress bar + total_pairs = num_frames - 1 + num_interp = multiplier - 1 + total_steps = total_pairs * num_interp + pbar = comfy.utils.ProgressBar(total_steps) + tqdm_bar = tqdm(total=total_steps, desc="Frame interpolation") + + batch = num_interp # reduced on OOM and persists across pairs (same resolution = same limit) + t_values = [t / multiplier for t in range(1, multiplier)] + + out_dtype = model_management.intermediate_dtype() + total_out_frames = total_pairs * multiplier + 1 + result = torch.empty((total_out_frames, 3, H, W), dtype=out_dtype, device=offload_device) + result[0] = images[0].movedim(-1, 0).to(out_dtype) + out_idx = 1 + + # Pre-compute timestep tensor on device (padded dimensions needed) + sample = prepare_frame(0) + pH, pW = sample.shape[2], sample.shape[3] + ts_full = torch.tensor(t_values, device=device, dtype=dtype).reshape(num_interp, 1, 1, 1) + ts_full = ts_full.expand(-1, 1, pH, pW) + del sample + + multi_fn = getattr(inference_model, "forward_multi_timestep", None) + feat_cache = {} + prev_frame = None + + try: + for i in range(total_pairs): + img0_single = prev_frame if prev_frame is not None else prepare_frame(i) + img1_single = prepare_frame(i + 1) + prev_frame = img1_single + + # Cache features: img1 of pair N becomes img0 of pair N+1 + feat_cache["img0"] = feat_cache.pop("next") if "next" in feat_cache else inference_model.extract_features(img0_single) + feat_cache["img1"] = inference_model.extract_features(img1_single) + feat_cache["next"] = feat_cache["img1"] + + used_multi = False + if multi_fn is not None: + # Models with timestep-independent flow can compute it once for all timesteps + try: + mids = multi_fn(img0_single, img1_single, t_values, cache=feat_cache) + result[out_idx:out_idx + num_interp] = mids[:, :, :H, :W].to(out_dtype) + out_idx += num_interp + pbar.update(num_interp) + tqdm_bar.update(num_interp) + used_multi = True + except model_management.OOM_EXCEPTION: + model_management.soft_empty_cache() + multi_fn = None # fall through to single-timestep path + + if not used_multi: + j = 0 + while j < num_interp: + b = min(batch, num_interp - j) + try: + img0 = img0_single.expand(b, -1, -1, -1) + img1 = img1_single.expand(b, -1, -1, -1) + mids = inference_model(img0, img1, timestep=ts_full[j:j + b], cache=feat_cache) + result[out_idx:out_idx + b] = mids[:, :, :H, :W].to(out_dtype) + out_idx += b + pbar.update(b) + tqdm_bar.update(b) + j += b + except model_management.OOM_EXCEPTION: + if batch <= 1: + raise + batch = max(1, batch // 2) + model_management.soft_empty_cache() + + result[out_idx] = images[i + 1].movedim(-1, 0).to(out_dtype) + out_idx += 1 + finally: + tqdm_bar.close() + + # BCHW -> BHWC + result = result.movedim(1, -1).clamp_(0.0, 1.0) + return io.NodeOutput(result) + + +class FrameInterpolationExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + FrameInterpolationModelLoader, + FrameInterpolate, + ] + + +async def comfy_entrypoint() -> FrameInterpolationExtension: + return FrameInterpolationExtension() diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index d7c2e8744..19d8a387f 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -1,6 +1,7 @@ import nodes import node_helpers import torch +import torchaudio import comfy.model_management import comfy.model_sampling import comfy.samplers @@ -711,7 +712,14 @@ class LTXVReferenceAudio(io.ComfyNode): @classmethod def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput: # Encode reference audio to latents and patchify - audio_latents = audio_vae.encode(reference_audio) + sample_rate = reference_audio["sample_rate"] + vae_sample_rate = getattr(audio_vae, "audio_sample_rate", 44100) + if vae_sample_rate != sample_rate: + waveform = torchaudio.functional.resample(reference_audio["waveform"], sample_rate, vae_sample_rate) + else: + waveform = reference_audio["waveform"] + + audio_latents = audio_vae.encode(waveform.movedim(1, -1)) b, c, t, f = audio_latents.shape ref_tokens = audio_latents.permute(0, 2, 1, 3).reshape(b, t, c * f) ref_audio = {"tokens": ref_tokens} diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 3e4222264..3ec635c75 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -3,9 +3,8 @@ import comfy.utils import comfy.model_management import torch -from comfy.ldm.lightricks.vae.audio_vae import AudioVAE from comfy_api.latest import ComfyExtension, io - +from comfy_extras.nodes_audio import VAEEncodeAudio class LTXVAudioVAELoader(io.ComfyNode): @classmethod @@ -28,10 +27,14 @@ class LTXVAudioVAELoader(io.ComfyNode): def execute(cls, ckpt_name: str) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) - return io.NodeOutput(AudioVAE(sd, metadata)) + sd = comfy.utils.state_dict_prefix_replace(sd, {"audio_vae.": "autoencoder.", "vocoder.": "vocoder."}, filter_keys=True) + vae = comfy.sd.VAE(sd=sd, metadata=metadata) + vae.throw_exception_if_invalid() + + return io.NodeOutput(vae) -class LTXVAudioVAEEncode(io.ComfyNode): +class LTXVAudioVAEEncode(VAEEncodeAudio): @classmethod def define_schema(cls) -> io.Schema: return io.Schema( @@ -50,15 +53,8 @@ class LTXVAudioVAEEncode(io.ComfyNode): ) @classmethod - def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: - audio_latents = audio_vae.encode(audio) - return io.NodeOutput( - { - "samples": audio_latents, - "sample_rate": int(audio_vae.sample_rate), - "type": "audio", - } - ) + def execute(cls, audio, audio_vae) -> io.NodeOutput: + return super().execute(audio_vae, audio) class LTXVAudioVAEDecode(io.ComfyNode): @@ -80,12 +76,12 @@ class LTXVAudioVAEDecode(io.ComfyNode): ) @classmethod - def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + def execute(cls, samples, audio_vae) -> io.NodeOutput: audio_latent = samples["samples"] if audio_latent.is_nested: audio_latent = audio_latent.unbind()[-1] - audio = audio_vae.decode(audio_latent).to(audio_latent.device) - output_audio_sample_rate = audio_vae.output_sample_rate + audio = audio_vae.decode(audio_latent).movedim(-1, 1).to(audio_latent.device) + output_audio_sample_rate = audio_vae.first_stage_model.output_sample_rate return io.NodeOutput( { "waveform": audio, @@ -143,17 +139,17 @@ class LTXVEmptyLatentAudio(io.ComfyNode): frames_number: int, frame_rate: int, batch_size: int, - audio_vae: AudioVAE, + audio_vae, ) -> io.NodeOutput: """Generate empty audio latents matching the reference pipeline structure.""" assert audio_vae is not None, "Audio VAE model is required" z_channels = audio_vae.latent_channels - audio_freq = audio_vae.latent_frequency_bins - sampling_rate = int(audio_vae.sample_rate) + audio_freq = audio_vae.first_stage_model.latent_frequency_bins + sampling_rate = int(audio_vae.first_stage_model.sample_rate) - num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + num_audio_latents = audio_vae.first_stage_model.num_of_latents_from_frames(frames_number, frame_rate) audio_latents = torch.zeros( (batch_size, z_channels, num_audio_latents, audio_freq), diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 0a1558f2b..17e25d514 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -1,5 +1,6 @@ import json from comfy.comfy_types.node_typing import IO +import torch # Preview Any - original implement from # https://github.com/rgthree/rgthree-comfy/blob/main/py/display_any.py @@ -19,6 +20,7 @@ class PreviewAny(): SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): + torch.set_printoptions(edgeitems=6) value = 'None' if isinstance(source, str): value = source @@ -33,6 +35,7 @@ class PreviewAny(): except Exception: value = 'source exists, but could not be serialized.' + torch.set_printoptions() return {"ui": {"text": (value,)}, "result": (value,)} NODE_CLASS_MAPPINGS = { diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py new file mode 100644 index 000000000..5cf92ccb3 --- /dev/null +++ b/comfy_extras/nodes_sam3.py @@ -0,0 +1,529 @@ +""" +SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking. +""" + +from typing_extensions import override + +import json +import os +import torch +import torch.nn.functional as F +import comfy.model_management +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io, ui +import av +from fractions import Fraction + + +def _extract_text_prompts(conditioning, device, dtype): + """Extract list of (text_embeddings, text_mask) from conditioning.""" + cond_meta = conditioning[0][1] + multi = cond_meta.get("sam3_multi_cond") + prompts = [] + if multi is not None: + for entry in multi: + emb = entry["cond"].to(device=device, dtype=dtype) + mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None + if mask is None: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, entry.get("max_detections", 1))) + else: + emb = conditioning[0][0].to(device=device, dtype=dtype) + mask = cond_meta.get("attention_mask") + if mask is not None: + mask = mask.to(device) + else: + mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) + prompts.append((emb, mask, 1)) + return prompts + + +def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations): + """Refine a coarse detector mask via SAM decoder, cropping to the detection box. + + Returns: [1, H, W] binary mask + """ + def _coarse_fallback(): + return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), + mode="bilinear", align_corners=False)[0] > 0).float() + + if iterations <= 0: + return _coarse_fallback() + + pad_frac = 0.1 + x1, y1, x2, y2 = box_xyxy.tolist() + bw, bh = x2 - x1, y2 - y1 + cx1 = max(0, int(x1 - bw * pad_frac)) + cy1 = max(0, int(y1 - bh * pad_frac)) + cx2 = min(W, int(x2 + bw * pad_frac)) + cy2 = min(H, int(y2 + bh * pad_frac)) + if cx2 <= cx1 or cy2 <= cy1: + return _coarse_fallback() + + crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3] + crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + crop_frame = crop_1008.to(device=device, dtype=dtype) + crop_h, crop_w = cy2 - cy1, cx2 - cx1 + + # Crop coarse mask and refine via SAM on the cropped image + mask_h, mask_w = coarse_mask.shape[-2:] + mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h) + mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h) + if mx2 <= mx1 or my2 <= my1: + return _coarse_fallback() + mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0) + for _ in range(iterations): + coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False) + mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input) + + refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False) + full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype) + full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop + coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False) + return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float() + + + +class SAM3_Detect(io.ComfyNode): + """Open-vocabulary detection and segmentation using text, box, or point prompts.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_Detect", + display_name="SAM3 Detect", + category="detection/", + search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], + inputs=[ + io.Model.Input("model", display_name="model"), + io.Image.Input("image", display_name="image"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"), + io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"), + io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), + io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01), + io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"), + io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"), + ], + outputs=[ + io.Mask.Output("masks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: + B, H, W, C = image.shape + image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + + # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors. + # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame). + def _boxes_to_tensor(box_list): + coords = [] + for d in box_list: + cx = (d["x"] + d["width"] / 2) / W + cy = (d["y"] + d["height"] / 2) / H + coords.append([cx, cy, d["width"] / W, d["height"] / H]) + return torch.tensor([coords], dtype=torch.float32) # [1, N, 4] + + per_frame_boxes = None + if bboxes is not None: + if isinstance(bboxes, dict): + # Single box → same for all frames + shared = _boxes_to_tensor([bboxes]) + per_frame_boxes = [shared] * B + elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list): + # list[list[dict]] → per-frame boxes + per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes] + # Pad to B if fewer frames provided + while len(per_frame_boxes) < B: + per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None) + elif isinstance(bboxes, list) and len(bboxes) > 0: + # list[dict] → same boxes for all frames + shared = _boxes_to_tensor(bboxes) + per_frame_boxes = [shared] * B + + # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...]) + pos_pts = json.loads(positive_coords) if positive_coords else [] + neg_pts = json.loads(negative_coords) if negative_coords else [] + has_points = len(pos_pts) > 0 or len(neg_pts) > 0 + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + # Build point inputs for tracker SAM decoder path + point_inputs = None + if has_points: + all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \ + [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts] + all_labels = [1] * len(pos_pts) + [0] * len(neg_pts) + point_inputs = { + "point_coords": torch.tensor([all_coords], dtype=dtype, device=device), + "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device), + } + + cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else [] + has_text = len(cond_list) > 0 + + # Run per-image through detector (text/boxes) and/or tracker (points) + all_bbox_dicts = [] + all_masks = [] + pbar = comfy.utils.ProgressBar(B) + + for b in range(B): + frame = image_in[b:b+1].to(device=device, dtype=dtype) + b_boxes = None + if per_frame_boxes is not None and per_frame_boxes[b] is not None: + b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype) + + frame_bbox_dicts = [] + frame_masks = [] + + # Point prompts: tracker SAM decoder path with iterative refinement + if point_inputs is not None: + mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Box prompts: SAM decoder path (segment inside each box) + if b_boxes is not None and not has_text: + for box_cxcywh in b_boxes[0]: + cx, cy, bw, bh = box_cxcywh.tolist() + # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners + sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008], + [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]], + device=device, dtype=dtype) + mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box) + for _ in range(max(0, refine_iterations - 1)): + mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) + mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) + frame_masks.append((mask[0] > 0).float()) + + # Text prompts: run detector per text prompt (each detects one category) + for text_embeddings, text_mask, max_det in cond_list: + results = sam3_model( + frame, text_embeddings=text_embeddings, text_mask=text_mask, + boxes=b_boxes, threshold=threshold, orig_size=(H, W)) + + pred_boxes = results["boxes"][0] + scores = results["scores"][0] + masks = results["masks"][0] + + probs = scores.sigmoid() + keep = probs > threshold + kept_boxes = pred_boxes[keep].cpu() + kept_scores = probs[keep].cpu() + kept_masks = masks[keep] + + order = kept_scores.argsort(descending=True)[:max_det] + kept_boxes = kept_boxes[order] + kept_scores = kept_scores[order] + kept_masks = kept_masks[order] + + for box, score in zip(kept_boxes, kept_scores): + frame_bbox_dicts.append({ + "x": float(box[0]), "y": float(box[1]), + "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), + "score": float(score), + }) + for m, box in zip(kept_masks, kept_boxes): + frame_masks.append(_refine_mask( + sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations)) + + all_bbox_dicts.append(frame_bbox_dicts) + if len(frame_masks) > 0: + combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W] + if individual_masks: + all_masks.append(combined) + else: + all_masks.append((combined > 0).any(dim=0).float()) + else: + if individual_masks: + all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device())) + else: + all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device())) + pbar.update(1) + + idev = comfy.model_management.intermediate_device() + all_masks = [m.to(idev) for m in all_masks] + mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks) + return io.NodeOutput(mask_out, all_bbox_dicts) + + +SAM3TrackData = io.Custom("SAM3_TRACK_DATA") + +class SAM3_VideoTrack(io.ComfyNode): + """Track objects across video frames using SAM3's memory-based tracker.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_VideoTrack", + display_name="SAM3 Video Track", + category="detection/", + search_aliases=["sam3", "video", "track", "propagate"], + inputs=[ + io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), + io.Model.Input("model", display_name="model"), + io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), + io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), + io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), + io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), + io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), + ], + outputs=[ + SAM3TrackData.Output("track_data", display_name="track_data"), + ], + ) + + @classmethod + def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput: + N, H, W, C = images.shape + + comfy.model_management.load_model_gpu(model) + device = comfy.model_management.get_torch_device() + dtype = model.model.get_dtype() + sam3_model = model.model.diffusion_model + + frames = images[..., :3].movedim(-1, 1) + frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) + + init_masks = None + if initial_mask is not None: + init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype) + + pbar = comfy.utils.ProgressBar(N) + + text_prompts = None + if conditioning is not None and len(conditioning) > 0: + text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)] + elif initial_mask is None: + raise ValueError("Either initial_mask or conditioning must be provided") + + result = sam3_model.forward_video( + images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, + new_det_thresh=detection_threshold, max_objects=max_objects, + detect_interval=detect_interval) + result["orig_size"] = (H, W) + return io.NodeOutput(result) + + +class SAM3_TrackPreview(io.ComfyNode): + """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackPreview", + display_name="SAM3 Track Preview", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.Image.Input("images", display_name="images", optional=True), + io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05), + io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0), + ], + is_output_node=True, + ) + + COLORS = [ + (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16), + (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5), + (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84), + ] + + # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3] + _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow) + + @staticmethod + def _get_glyphs(device, scale=3): + key = (device, scale) + if key in SAM3_TrackPreview._glyph_cache: + return SAM3_TrackPreview._glyph_cache[key] + atlas = torch.tensor([ + [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]], + [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]], + [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]], + [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]], + [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]], + [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]], + [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]], + ], dtype=torch.bool) + glyphs, outlines = [], [] + for d in range(10): + g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1) + padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1)) + o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0) + glyphs.append(g.to(device)) + outlines.append(o.to(device)) + gh, gw = glyphs[0].shape + oh, ow = outlines[0].shape + SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow) + return SAM3_TrackPreview._glyph_cache[key] + + @staticmethod + def _draw_number_gpu(frame, number, cx, cy, color, scale=3): + """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline.""" + H, W = frame.shape[:2] + device = frame.device + glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale) + color_t = torch.tensor(color, device=device, dtype=frame.dtype) + digs = [int(d) for d in str(number)] + total_w = len(digs) * (gw + scale) - scale + x0 = cx - total_w // 2 + y0 = cy - gh // 2 + for i, d in enumerate(digs): + dx = x0 + i * (gw + scale) + # Black outline + oy0, ox0 = y0 - 1, dx - 1 + osy1, osx1 = max(0, -oy0), max(0, -ox0) + osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0) + if osy2 > osy1 and osx2 > osx1: + fy1, fx1 = oy0 + osy1, ox0 + osx1 + frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0 + # Colored fill + sy1, sx1 = max(0, -y0), max(0, -dx) + sy2, sx2 = min(gh, H - y0), min(gw, W - dx) + if sy2 > sy1 and sx2 > sx1: + fy1, fx1 = y0 + sy1, dx + sx1 + frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t + + @classmethod + def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput: + + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + if images is not None: + H, W = images.shape[1], images.shape[2] + if packed is None: + N, N_obj = track_data["n_frames"], 0 + else: + N, N_obj = packed.shape[0], packed.shape[1] + + import uuid + gpu = comfy.model_management.get_torch_device() + temp_dir = folder_paths.get_temp_directory() + filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4" + filepath = os.path.join(temp_dir, filename) + with av.open(filepath, mode='w') as output: + stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000)) + stream.width = W + stream.height = H + stream.pix_fmt = 'yuv420p' + + frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8) + frame_np = frame_cpu.numpy() + if N_obj > 0: + colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)], + device=gpu, dtype=torch.float32) + grid_y = torch.arange(H, device=gpu).view(1, H, 1) + grid_x = torch.arange(W, device=gpu).view(1, 1, W) + for t in range(N): + if images is not None and t < images.shape[0]: + frame = images[t].clone() + else: + frame = torch.zeros(H, W, 3) + + if N_obj > 0: + frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool + frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0] + frame_gpu = frame.to(gpu) + bool_masks = frame_masks > 0.5 + any_mask = bool_masks.any(dim=0) + if any_mask.any(): + obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0) + color_overlay = colors_t[obj_idx_map] + mask_3d = any_mask.unsqueeze(-1) + frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu) + area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1) + cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area + cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area + has = area > 1 + scores = track_data.get("scores", []) + for obj_idx in range(N_obj): + if has[obj_idx]: + _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) + color = cls.COLORS[obj_idx % len(cls.COLORS)] + SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) + if obj_idx < len(scores) and scores[obj_idx] < 1.0: + SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), + _cx, _cy + 5 * 3 + 3, color, scale=2) + frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) + else: + frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) + + vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24') + output.mux(stream.encode(vframe.reformat(format='yuv420p'))) + output.mux(stream.encode(None)) + return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)])) + + +class SAM3_TrackToMask(io.ComfyNode): + """Select tracked objects by index and output as mask.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SAM3_TrackToMask", + display_name="SAM3 Track to Mask", + category="detection/", + inputs=[ + SAM3TrackData.Input("track_data", display_name="track_data"), + io.String.Input("object_indices", display_name="object_indices", default="", + tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."), + ], + outputs=[ + io.Mask.Output("masks", display_name="masks"), + ], + ) + + @classmethod + def execute(cls, track_data, object_indices="") -> io.NodeOutput: + from comfy.ldm.sam3.tracker import unpack_masks + packed = track_data["packed_masks"] + H, W = track_data["orig_size"] + + if packed is None: + N = track_data["n_frames"] + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + N, N_obj = packed.shape[0], packed.shape[1] + + if object_indices.strip(): + indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] + indices = [i for i in indices if 0 <= i < N_obj] + else: + indices = list(range(N_obj)) + + if not indices: + return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) + + selected = packed[:, indices] + binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool + union = binary.any(dim=1, keepdim=True).float() + mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] + return io.NodeOutput(mask_out) + + +class SAM3Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SAM3_Detect, + SAM3_VideoTrack, + SAM3_TrackPreview, + SAM3_TrackToMask, + ] + + +async def comfy_entrypoint() -> SAM3Extension: + return SAM3Extension() diff --git a/execution.py b/execution.py index 5e02dffb2..e15eb4bda 100644 --- a/execution.py +++ b/execution.py @@ -811,11 +811,30 @@ class PromptExecutor: self._notify_prompt_lifecycle("end", prompt_id) -async def validate_inputs(prompt_id, prompt, item, validated): +async def validate_inputs(prompt_id, prompt, item, validated, visiting=None): + if visiting is None: + visiting = [] + unique_id = item if unique_id in validated: return validated[unique_id] + if unique_id in visiting: + cycle_path_nodes = visiting[visiting.index(unique_id):] + [unique_id] + cycle_nodes = list(dict.fromkeys(cycle_path_nodes)) + cycle_path = " -> ".join(f"{node_id} ({prompt[node_id]['class_type']})" for node_id in cycle_path_nodes) + for node_id in cycle_nodes: + validated[node_id] = (False, [{ + "type": "dependency_cycle", + "message": "Dependency cycle detected", + "details": cycle_path, + "extra_info": { + "node_id": node_id, + "cycle_nodes": cycle_nodes, + } + }], node_id) + return validated[unique_id] + inputs = prompt[unique_id]['inputs'] class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] @@ -899,7 +918,11 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue try: - r = await validate_inputs(prompt_id, prompt, o_id, validated) + visiting.append(unique_id) + try: + r = await validate_inputs(prompt_id, prompt, o_id, validated, visiting) + finally: + visiting.pop() if r[0] is False: # `r` will be set in `validated[o_id]` already valid = False @@ -1048,10 +1071,13 @@ async def validate_inputs(prompt_id, prompt, item, validated): errors.append(error) continue - if len(errors) > 0 or valid is not True: - ret = (False, errors, unique_id) - else: - ret = (True, [], unique_id) + ret = validated.get(unique_id, (True, [], unique_id)) + # Recursive cycle detection may have already populated an error on us. Join it. + ret = ( + ret[0] and valid is True and not errors, + ret[1] + [error for error in errors if error not in ret[1]], + unique_id, + ) validated[unique_id] = ret return ret diff --git a/folder_paths.py b/folder_paths.py index 9c96540e3..80f4b291a 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) +folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/main.py b/main.py index 12b04719d..dbaf2745c 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,8 @@ import folder_paths import time from comfy.cli_args import args, enables_dynamic_vram from app.logger import setup_logger +setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) + from app.assets.seeder import asset_seeder from app.assets.services import register_output_files import itertools @@ -27,8 +29,6 @@ if __name__ == "__main__": os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['DO_NOT_TRACK'] = '1' -setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) - faulthandler.enable(file=sys.stderr, all_threads=False) import comfy_aimdo.control diff --git a/manager_requirements.txt b/manager_requirements.txt index f770ec933..a079d3492 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1 +comfyui_manager==4.2.1 diff --git a/models/frame_interpolation/put_frame_interpolation_models_here b/models/frame_interpolation/put_frame_interpolation_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index 299b3d758..fb83da896 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,9 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_frame_interpolation.py", + "nodes_sam3.py" ] import_failed = [] diff --git a/openapi.yaml b/openapi.yaml new file mode 100644 index 000000000..77d0e2318 --- /dev/null +++ b/openapi.yaml @@ -0,0 +1,3231 @@ +openapi: 3.1.0 +info: + title: ComfyUI API + description: | + API for ComfyUI - A powerful and modular stable diffusion GUI and backend. + + This API allows you to interact with ComfyUI programmatically, including: + - Submitting and managing workflow executions + - Querying node/object information + - Uploading and viewing files + - Managing user settings and data + - Asset management (feature-gated) + + ## Dual-path routing + Every route registered via `self.routes` in the ComfyUI server is available at + both its bare path (e.g. `/prompt`) and an `/api`-prefixed path (e.g. `/api/prompt`). + This spec uses the `/api`-prefixed versions as canonical. + + ## Multi-user mode + When ComfyUI is started with `--multi-user`, the `Comfy-User` header identifies + the active user for settings, userdata, and history isolation. This is **not** a + security mechanism — it is an organisational convenience with no authentication + or authorisation behind it. + version: 1.0.0 + license: + name: GNU General Public License v3.0 + url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE + +servers: + - url: / + description: Default ComfyUI server (typically http://127.0.0.1:8188) + +tags: + - name: prompt + description: Workflow submission and prompt info + - name: queue + description: Queue inspection and management + - name: history + description: Execution history + - name: upload + description: File upload endpoints + - name: view + description: File viewing / download + - name: system + description: System stats and feature flags + - name: node + description: Node / object_info definitions + - name: model + description: Model folder and file listing + - name: user + description: User management (multi-user mode) + - name: userdata + description: Per-user file storage + - name: settings + description: Per-user settings + - name: extensions + description: Frontend extension JS files + - name: subgraph + description: Global subgraph blueprints + - name: internal + description: Internal / debug endpoints + - name: assets + description: Asset management (feature-gated behind enable-assets) + +paths: + # --------------------------------------------------------------------------- + # WebSocket + # --------------------------------------------------------------------------- + /ws: + get: + operationId: connectWebSocket + tags: [system] + summary: WebSocket connection for real-time updates + description: | + Upgrades to a WebSocket connection that streams execution progress, + node status, and output messages. The server sends an initial `status` + message with the session ID (SID) on connect. + + ## Message types (server → client) + The server sends JSON messages with a `type` field. See the + `x-websocket-messages` list below for the schema of each message type. + parameters: + - name: clientId + in: query + required: false + schema: + type: string + description: Client identifier. If omitted the server assigns one. + responses: + "101": + description: WebSocket upgrade successful + x-websocket-messages: + - type: status + schema: + $ref: "#/components/schemas/StatusWsMessage" + - type: progress + schema: + $ref: "#/components/schemas/ProgressWsMessage" + - type: progress_text + schema: + $ref: "#/components/schemas/ProgressTextWsMessage" + - type: progress_state + schema: + $ref: "#/components/schemas/ProgressStateWsMessage" + - type: executing + schema: + $ref: "#/components/schemas/ExecutingWsMessage" + - type: executed + schema: + $ref: "#/components/schemas/ExecutedWsMessage" + - type: execution_start + schema: + $ref: "#/components/schemas/ExecutionStartWsMessage" + - type: execution_success + schema: + $ref: "#/components/schemas/ExecutionSuccessWsMessage" + - type: execution_cached + schema: + $ref: "#/components/schemas/ExecutionCachedWsMessage" + - type: execution_interrupted + schema: + $ref: "#/components/schemas/ExecutionInterruptedWsMessage" + - type: execution_error + schema: + $ref: "#/components/schemas/ExecutionErrorWsMessage" + - type: logs + schema: + $ref: "#/components/schemas/LogsWsMessage" + - type: notification + schema: + $ref: "#/components/schemas/NotificationWsMessage" + - type: feature_flags + schema: + $ref: "#/components/schemas/FeatureFlagsWsMessage" + - type: asset_download + schema: + $ref: "#/components/schemas/AssetDownloadWsMessage" + - type: asset_export + schema: + $ref: "#/components/schemas/AssetExportWsMessage" + + # --------------------------------------------------------------------------- + # Prompt + # --------------------------------------------------------------------------- + /api/prompt: + get: + operationId: getPromptInfo + tags: [prompt] + summary: Get queue status + description: Returns how many items remain in the execution queue. + responses: + "200": + description: Queue info + content: + application/json: + schema: + $ref: "#/components/schemas/PromptInfo" + post: + operationId: executePrompt + tags: [prompt] + summary: Submit a workflow for execution + description: Submits a workflow for execution. The server validates the graph, assigns a `prompt_id`, and enqueues it. Clients listen on `/ws` for execution progress and output messages. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/PromptRequest" + responses: + "200": + description: Prompt accepted + content: + application/json: + schema: + $ref: "#/components/schemas/PromptResponse" + "400": + description: Validation or node errors + content: + application/json: + schema: + $ref: "#/components/schemas/PromptErrorResponse" + + # --------------------------------------------------------------------------- + # Queue + # --------------------------------------------------------------------------- + /api/queue: + get: + operationId: getQueue + tags: [queue] + summary: Get running and pending queue items + description: Returns the server's current execution queue, split into the currently-running prompt and the list of pending prompts. + responses: + "200": + description: Queue contents + content: + application/json: + schema: + $ref: "#/components/schemas/QueueInfo" + post: + operationId: manageQueue + tags: [queue] + summary: Clear or delete items from the queue + description: Mutates the execution queue. Supports clearing all queued prompts or deleting individual prompts by ID. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/QueueManageRequest" + responses: + "200": + description: Queue updated + + /api/interrupt: + post: + operationId: interruptExecution + tags: [queue] + summary: Interrupt current execution + description: Interrupts the prompt that is currently executing. The next queued prompt (if any) will start immediately after. + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + prompt_id: + type: string + format: uuid + description: "If provided, only interrupts this specific running prompt. Otherwise interrupts all." + responses: + "200": + description: Interrupt signal sent + + /api/free: + post: + operationId: freeMemory + tags: [queue] + summary: Free GPU memory and/or unload models + description: Frees GPU memory by unloading models and/or freeing the resident model cache, controlled by the request flags. + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + unload_models: + type: boolean + description: Unload all models from VRAM/RAM + free_memory: + type: boolean + description: Run garbage collection and free cached memory + responses: + "200": + description: Memory freed + + # --------------------------------------------------------------------------- + # Jobs + # --------------------------------------------------------------------------- + /api/jobs: + get: + operationId: listJobs + tags: [queue] + summary: List jobs with filtering and pagination + description: Returns a paginated list of completed prompt executions, newest first. + parameters: + - name: status + in: query + schema: + type: string + description: Filter by job status + - name: workflow_id + in: query + schema: + type: string + description: Filter by workflow ID + - name: sort_by + in: query + schema: + type: string + description: Field to sort by + - name: sort_order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + - name: limit + in: query + schema: + type: integer + description: Maximum number of results (default is unlimited/None) + - name: offset + in: query + schema: + type: integer + default: 0 + description: Pagination offset + responses: + "200": + description: Jobs list + content: + application/json: + schema: + type: object + properties: + jobs: + type: array + items: + $ref: "#/components/schemas/JobEntry" + pagination: + $ref: "#/components/schemas/PaginationInfo" + + /api/jobs/{job_id}: + get: + operationId: getJob + tags: [queue] + summary: Get a single job by ID + description: Returns the full record for a single completed prompt execution, including its outputs, status, and metadata. + parameters: + - name: job_id + in: path + description: The job (prompt) ID to fetch. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Job detail + content: + application/json: + schema: + $ref: "#/components/schemas/JobDetailResponse" + "404": + description: Job not found + + # --------------------------------------------------------------------------- + # History + # --------------------------------------------------------------------------- + /api/history: + get: + operationId: getHistory + tags: [history] + summary: Get execution history + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs`, which returns the same + execution records in a paginated, filterable format. Planned for removal + no earlier than a future major release; sunset timeline TBD. + + Returns a dictionary keyed by prompt_id. Each value is a HistoryEntry + containing prompt metadata, outputs, status, and node meta. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: max_items + in: query + schema: + type: integer + description: Maximum number of history entries to return + - name: offset + in: query + schema: + type: integer + description: Pagination offset (number of entries to skip) + responses: + "200": + description: History dictionary keyed by prompt_id + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/HistoryEntry" + post: + operationId: manageHistory + tags: [history] + summary: Clear or delete history entries + deprecated: true + description: | + **Deprecated.** Superseded by the forthcoming job-management endpoints + under `/api/jobs`. Planned for removal no earlier than a future major + release; sunset timeline TBD. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/HistoryManageRequest" + responses: + "200": + description: History updated + + /api/history/{prompt_id}: + get: + operationId: getHistoryByPromptId + tags: [history] + summary: Get history for a specific prompt + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs/{job_id}`, which returns + the same execution record. Planned for removal no earlier than a future + major release; sunset timeline TBD. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: prompt_id + in: path + description: The prompt ID to fetch history for. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Single-entry history dictionary. Returns an empty object `{}` if the prompt_id is not found. + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/HistoryEntry" + + # --------------------------------------------------------------------------- + # Upload + # --------------------------------------------------------------------------- + /api/upload/image: + post: + operationId: uploadImage + tags: [upload] + summary: Upload an image file + description: Uploads an image file into one of the input/output/temp directories so it can be referenced by workflow nodes. + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - image + properties: + image: + type: string + format: binary + description: Image file to upload + type: + type: string + enum: [input, temp, output] + default: input + description: Target directory type + overwrite: + type: string + description: 'Set to "true" to overwrite existing files' + subfolder: + type: string + description: Subfolder within the target directory + responses: + "200": + description: Upload result + content: + application/json: + schema: + $ref: "#/components/schemas/UploadResult" + "400": + description: No file provided or invalid request + + /api/upload/mask: + post: + operationId: uploadMask + tags: [upload] + summary: Upload a mask image + description: Uploads a mask image associated with a previously-uploaded reference image. + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - image + - original_ref + properties: + image: + type: string + format: binary + description: Mask image (alpha channel is used) + original_ref: + type: object + description: Reference to the original image file + required: + - filename + properties: + filename: + type: string + description: Filename of the original image + additionalProperties: true + type: + type: string + enum: [input, temp, output] + default: input + description: Target directory type + overwrite: + type: string + description: 'Set to "true" to overwrite existing files' + subfolder: + type: string + description: Subfolder within the target directory + responses: + "200": + description: Upload result + content: + application/json: + schema: + $ref: "#/components/schemas/UploadResult" + "400": + description: No file provided or invalid request + + # --------------------------------------------------------------------------- + # View + # --------------------------------------------------------------------------- + /api/view: + get: + operationId: viewFile + tags: [view] + summary: View or download a file + description: Serves a file (image, audio, or video) from the input/output/temp directory identified by the query parameters. + parameters: + - name: filename + in: query + required: true + schema: + type: string + description: Name of the file to view + - name: type + in: query + schema: + type: string + enum: [input, output, temp] + default: output + description: Directory type + - name: subfolder + in: query + schema: + type: string + description: Subfolder within the directory + - name: preview + in: query + schema: + type: string + description: Preview format hint (e.g. "webp;90") + - name: channel + in: query + schema: + type: string + enum: [rgba, rgb, a] + description: Channel extraction mode + responses: + "200": + description: File content + content: + image/*: + schema: + type: string + format: binary + video/*: + schema: + type: string + format: binary + audio/*: + schema: + type: string + format: binary + application/octet-stream: + schema: + type: string + format: binary + "404": + description: File not found + + /api/view_metadata/{folder_name}: + get: + operationId: viewMetadata + tags: [view] + summary: Get metadata for a file (e.g. safetensors header) + description: Returns embedded metadata parsed from a file in the given folder — for example, the header of a safetensors model. + parameters: + - name: folder_name + in: path + required: true + schema: + type: string + description: Folder type (output, input, temp, etc.) + - name: filename + in: query + required: true + schema: + type: string + description: Filename to read metadata from + responses: + "200": + description: File metadata + content: + application/json: + schema: + type: object + additionalProperties: true + "404": + description: File or metadata not found + + # --------------------------------------------------------------------------- + # System + # --------------------------------------------------------------------------- + /api/system_stats: + get: + operationId: getSystemStats + tags: [system] + summary: Get system statistics + description: Returns hardware, Python, VRAM, and runtime statistics for the running ComfyUI process. + responses: + "200": + description: System stats + content: + application/json: + schema: + $ref: "#/components/schemas/SystemStatsResponse" + + /api/features: + get: + operationId: getFeatures + tags: [system] + summary: Get enabled feature flags + description: Returns a dictionary of feature flag names to their enabled state. + responses: + "200": + description: Feature flags + content: + application/json: + schema: + type: object + additionalProperties: + type: boolean + + # --------------------------------------------------------------------------- + # Node / Object Info + # --------------------------------------------------------------------------- + /api/object_info: + get: + operationId: getObjectInfo + tags: [node] + summary: Get all node definitions + description: | + Returns a dictionary of every registered node class, keyed by class name. + Each value is a NodeInfo object describing inputs, outputs, category, etc. + responses: + "200": + description: All node definitions + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/NodeInfo" + + /api/object_info/{node_class}: + get: + operationId: getObjectInfoByClass + tags: [node] + summary: Get a single node definition + description: Returns the `NodeInfo` definition for a single registered node class. + parameters: + - name: node_class + in: path + required: true + schema: + type: string + description: Node class name (e.g. "KSampler") + responses: + "200": + description: Single node definition + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/NodeInfo" + "404": + description: Node class not found + + /api/embeddings: + get: + operationId: getEmbeddings + tags: [node] + summary: List available embedding names + description: Returns the list of text-encoder embeddings available on disk. + responses: + "200": + description: Embedding names + content: + application/json: + schema: + type: array + items: + type: string + + # --------------------------------------------------------------------------- + # Models + # --------------------------------------------------------------------------- + /api/models: + get: + operationId: getModelTypes + tags: [model] + summary: List model folder type names + description: Returns an array of model type names (e.g. checkpoints, loras, vae). + responses: + "200": + description: Model type names + content: + application/json: + schema: + type: array + items: + type: string + + /api/models/{folder}: + get: + operationId: getModelsByFolder + tags: [model] + summary: List model filenames in a folder + description: Returns the names of model files in the given folder. This endpoint predates `/api/experiment/models/{folder}` and returns names only — prefer the experiment endpoint for new integrations. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + responses: + "200": + description: Model filenames + content: + application/json: + schema: + type: array + items: + type: string + "404": + description: Unknown folder type + + /api/experiment/models: + get: + operationId: getExperimentModels + tags: [model] + summary: List model folders with paths + description: Returns an array of model folder objects with name and folder paths. + responses: + "200": + description: Model folders + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ModelFolder" + + /api/experiment/models/{folder}: + get: + operationId: getExperimentModelsByFolder + tags: [model] + summary: List model files with metadata + description: Returns the model files in the given folder with richer metadata (path index, mtime, size) than the legacy `/api/models/{folder}` endpoint. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + responses: + "200": + description: Model files with metadata + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/ModelFile" + "404": + description: Unknown folder type + + /api/experiment/models/preview/{folder}/{path_index}/{filename}: + get: + operationId: getModelPreview + tags: [model] + summary: Get model preview image + description: Returns the preview image associated with a model file, if one exists alongside the model on disk. + parameters: + - name: folder + in: path + required: true + schema: + type: string + description: Model folder type name + - name: path_index + in: path + required: true + schema: + type: integer + description: Path index within the folder + - name: filename + in: path + required: true + schema: + type: string + description: Model filename + responses: + "200": + description: Preview image (WebP) + content: + image/webp: + schema: + type: string + format: binary + "404": + description: Preview not found + + # --------------------------------------------------------------------------- + # Users + # --------------------------------------------------------------------------- + /api/users: + get: + operationId: getUsers + tags: [user] + summary: Get user storage info + description: | + Returns user storage configuration. In single-user mode returns + `{"storage": "server", "migrated": true/false}`. In multi-user mode + returns `{"storage": "server", "users": {"user_id": "user_dir", ...}}`. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + responses: + "200": + description: User info + content: + application/json: + schema: + type: object + properties: + storage: + type: string + description: Storage backend type (always "server") + migrated: + type: boolean + description: Whether migration from browser storage is complete (single-user) + users: + type: object + additionalProperties: + type: string + description: Map of user_id to directory name (multi-user) + post: + operationId: createUser + tags: [user] + summary: Create a new user (multi-user mode) + description: Creates a new user entry. Only meaningful when ComfyUI is running in multi-user mode. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - username + properties: + username: + type: string + description: Username for the new user + responses: + "200": + description: Created user ID + content: + application/json: + schema: + type: string + description: The generated user_id + "400": + description: Username already exists or invalid + + # --------------------------------------------------------------------------- + # Userdata + # --------------------------------------------------------------------------- + /api/userdata: + get: + operationId: listUserdata + tags: [userdata] + summary: List files in a userdata directory + description: Lists files in the authenticated user's data directory. Returns either filename strings or full objects depending on the `full_info` query parameter. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: dir + in: query + required: true + schema: + type: string + description: Directory path relative to the user's data folder + - name: recurse + in: query + schema: + type: boolean + description: Recurse into subdirectories + - name: full_info + in: query + schema: + type: boolean + description: Return full file info objects instead of just names + - name: split + in: query + schema: + type: boolean + description: Split paths into directory components + responses: + "200": + description: File listing + content: + application/json: + schema: + $ref: "#/components/schemas/ListUserdataResponse" + "404": + description: Directory not found + + /api/v2/userdata: + get: + operationId: listUserdataV2 + tags: [userdata] + summary: List files in userdata (v2 format) + description: Lists files in the authenticated user's data directory using the v2 response shape, which always returns full objects. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: path + in: query + schema: + type: string + description: Directory path relative to user data root + responses: + "200": + description: File listing with metadata + content: + application/json: + schema: + type: array + items: + type: object + properties: + name: + type: string + path: + type: string + type: + type: string + enum: [file, directory] + size: + type: integer + modified: + type: number + description: Unix timestamp + + /api/userdata/{file}: + get: + operationId: getUserdataFile + tags: [userdata] + summary: Read a userdata file + description: Reads the contents of a file from the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "200": + description: File content + content: + application/octet-stream: + schema: + type: string + format: binary + "404": + description: File not found + post: + operationId: writeUserdataFile + tags: [userdata] + summary: Write or create a userdata file + description: Writes (creates or replaces) a file in the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + - name: overwrite + in: query + schema: + type: boolean + description: Allow overwriting existing files + - name: full_info + in: query + schema: + type: boolean + description: Return full file info in response + requestBody: + required: true + content: + application/octet-stream: + schema: + type: string + format: binary + application/json: + schema: {} + responses: + "200": + description: File written + content: + application/json: + schema: + $ref: "#/components/schemas/UserDataResponse" + "409": + description: File exists and overwrite not set + delete: + operationId: deleteUserdataFile + tags: [userdata] + summary: Delete a userdata file + description: Deletes a file from the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: File path relative to user data directory + responses: + "204": + description: File deleted + "404": + description: File not found + + /api/userdata/{file}/move/{dest}: + post: + operationId: moveUserdataFile + tags: [userdata] + summary: Move or rename a userdata file + description: Renames or moves a file within the authenticated user's data directory. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: file + in: path + required: true + schema: + type: string + description: Source file path + - name: dest + in: path + required: true + schema: + type: string + description: Destination file path + - name: overwrite + in: query + schema: + type: boolean + description: Allow overwriting at destination + - name: full_info + in: query + schema: + type: boolean + description: Return full file info in response + responses: + "200": + description: File moved + content: + application/json: + schema: + $ref: "#/components/schemas/UserDataResponse" + "404": + description: Source file not found + "409": + description: Destination exists and overwrite not set + + # --------------------------------------------------------------------------- + # Settings + # --------------------------------------------------------------------------- + /api/settings: + get: + operationId: getSettings + tags: [settings] + summary: Get all user settings + description: Returns all settings for the authenticated user. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + responses: + "200": + description: Settings object + content: + application/json: + schema: + type: object + additionalProperties: true + post: + operationId: updateSettings + tags: [settings] + summary: Update user settings (partial merge) + description: Replaces the authenticated user's settings with the provided object. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + requestBody: + required: true + content: + application/json: + schema: + type: object + additionalProperties: true + description: Partial settings to merge + responses: + "200": + description: Settings updated + + /api/settings/{id}: + get: + operationId: getSetting + tags: [settings] + summary: Get a single setting by key + description: Returns the value of a single setting, identified by key. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: id + in: path + required: true + schema: + type: string + description: Setting key + responses: + "200": + description: Setting value (null if the setting does not exist) + content: + application/json: + schema: + nullable: true + description: The setting value (any JSON type), or null if not set + post: + operationId: updateSetting + tags: [settings] + summary: Set a single setting value + description: Sets the value of a single setting, identified by key. + parameters: + - $ref: "#/components/parameters/ComfyUserHeader" + - name: id + in: path + required: true + schema: + type: string + description: Setting key + requestBody: + required: true + content: + application/json: + schema: + description: The setting value (any JSON type) + responses: + "200": + description: Setting updated + + # --------------------------------------------------------------------------- + # Extensions / Templates / i18n + # --------------------------------------------------------------------------- + /api/extensions: + get: + operationId: getExtensions + tags: [extensions] + summary: List frontend extension JS file paths + description: Returns the list of frontend extension JS URLs registered by custom nodes, to be loaded by the frontend on startup. + responses: + "200": + description: Array of JS file paths + content: + application/json: + schema: + type: array + items: + type: string + description: Relative path to extension JS file + + /api/workflow_templates: + get: + operationId: getWorkflowTemplates + tags: [extensions] + summary: Get workflow template mappings + description: Returns a map of custom node names to their provided workflow template names. + responses: + "200": + description: Template mappings + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + type: string + description: Map of node pack name to array of template names + + /api/i18n: + get: + operationId: getI18n + tags: [extensions] + summary: Get internationalisation translation strings + description: Returns the URLs of translation files contributed by custom nodes, keyed by locale. + responses: + "200": + description: Translation map + content: + application/json: + schema: + type: object + additionalProperties: true + description: Nested map of locale to translation key-value pairs + + # --------------------------------------------------------------------------- + # Subgraphs + # --------------------------------------------------------------------------- + /api/global_subgraphs: + get: + operationId: getGlobalSubgraphs + tags: [subgraph] + summary: List global subgraph blueprints + description: Returns a dictionary of subgraph IDs to their metadata. + responses: + "200": + description: Subgraph metadata dictionary + content: + application/json: + schema: + type: object + additionalProperties: + $ref: "#/components/schemas/GlobalSubgraphInfo" + + /api/global_subgraphs/{id}: + get: + operationId: getGlobalSubgraph + tags: [subgraph] + summary: Get a global subgraph with full data + description: Returns the blueprint for a globally-registered subgraph, used by the frontend to materialize the subgraph node. + parameters: + - name: id + in: path + required: true + schema: + type: string + description: Subgraph identifier + responses: + "200": + description: Full subgraph data + content: + application/json: + schema: + $ref: "#/components/schemas/GlobalSubgraphData" + "404": + description: Subgraph not found + + # --------------------------------------------------------------------------- + # Node Replacements + # --------------------------------------------------------------------------- + /api/node_replacements: + get: + operationId: getNodeReplacements + tags: [node] + summary: Get node replacement mappings + description: | + Returns a dictionary mapping deprecated or replaced node class names + to their replacement node information. + responses: + "200": + description: Replacement mappings + content: + application/json: + schema: + type: object + additionalProperties: true + + # --------------------------------------------------------------------------- + # Internal (x-internal: true) + # --------------------------------------------------------------------------- + /internal/logs: + get: + operationId: getInternalLogs + tags: [internal] + summary: Get server logs as text + description: Returns structured ComfyUI log entries from the in-memory log buffer. + x-internal: true + responses: + "200": + description: Log text + content: + text/plain: + schema: + type: string + + /internal/logs/raw: + get: + operationId: getInternalLogsRaw + tags: [internal] + summary: Get raw structured log entries + description: Returns the raw ComfyUI log buffer as text, together with metadata about the current size limit. + x-internal: true + responses: + "200": + description: Structured log data + content: + application/json: + schema: + type: object + properties: + entries: + type: array + items: + type: object + properties: + t: + type: number + description: Timestamp + m: + type: string + description: Message + size: + type: object + properties: + cols: + type: integer + rows: + type: integer + + /internal/logs/subscribe: + patch: + operationId: subscribeToLogs + tags: [internal] + summary: Subscribe or unsubscribe a WebSocket client to log streaming + description: Subscribes or unsubscribes the current client from live log streaming over the WebSocket. + x-internal: true + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - clientId + - enabled + properties: + clientId: + type: string + description: WebSocket client ID + enabled: + type: boolean + description: Enable or disable log streaming for this client + responses: + "200": + description: Subscription updated + + /internal/folder_paths: + get: + operationId: getInternalFolderPaths + tags: [internal] + summary: Get configured folder paths + description: Returns the filesystem paths ComfyUI is configured to load models and other assets from, keyed by folder type. + x-internal: true + responses: + "200": + description: Dictionary of folder type to paths + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: + type: array + items: + type: string + description: Map of folder type name to list of [path, ...] entries + + /internal/files/{directory_type}: + get: + operationId: getInternalFiles + tags: [internal] + summary: List files in a directory type + description: Lists the files present in one of ComfyUI's known directories (input, output, or temp). + x-internal: true + parameters: + - name: directory_type + in: path + required: true + schema: + type: string + description: Directory type (e.g. output, input, temp) + responses: + "200": + description: Array of filenames + content: + application/json: + schema: + type: array + items: + type: string + + # --------------------------------------------------------------------------- + # Assets (x-feature-gate: enable-assets) + # --------------------------------------------------------------------------- + /api/assets/hash/{hash}: + head: + operationId: checkAssetByHash + tags: [assets] + summary: Check if an asset with the given hash exists + description: Returns 204 if an asset with the given content hash already exists, 404 otherwise. Used by clients to deduplicate uploads before transferring bytes. + x-feature-gate: enable-assets + parameters: + - name: hash + in: path + required: true + schema: + type: string + description: "Blake3 hash of the asset (e.g. blake3:abc123...)" + responses: + "200": + description: Asset exists + "404": + description: No asset with this hash + + /api/assets: + get: + operationId: listAssets + tags: [assets] + summary: List assets with filtering and pagination + description: Returns a paginated list of assets, optionally filtered by tags, name, or other query parameters. + x-feature-gate: enable-assets + parameters: + - name: limit + in: query + schema: + type: integer + default: 50 + - name: offset + in: query + schema: + type: integer + default: 0 + - name: include_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must have (AND logic) + - name: exclude_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must not have + - name: name_contains + in: query + schema: + type: string + description: Filter assets whose name contains this substring + - name: metadata_filter + in: query + schema: + type: string + description: JSON-encoded metadata key/value filter + - name: sort + in: query + schema: + type: string + description: Field to sort by + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + responses: + "200": + description: Asset list + content: + application/json: + schema: + $ref: "#/components/schemas/ListAssetsResponse" + post: + operationId: createAsset + tags: [assets] + summary: Upload a new asset + description: Uploads a new asset (binary content plus metadata) and registers it in the asset database. + x-feature-gate: enable-assets + requestBody: + required: true + content: + multipart/form-data: + schema: + type: object + required: + - file + properties: + file: + type: string + format: binary + description: Asset file to upload + name: + type: string + description: Display name for the asset + tags: + type: string + description: Comma-separated tags + user_metadata: + type: string + description: JSON-encoded user metadata + hash: + type: string + description: "Blake3 hash of the file content (e.g. blake3:abc123...)" + mime_type: + type: string + description: MIME type of the file (overrides auto-detected type) + preview_id: + type: string + format: uuid + description: ID of an existing asset to use as the preview image + responses: + "201": + description: Asset created + content: + application/json: + schema: + $ref: "#/components/schemas/AssetCreated" + + /api/assets/from-hash: + post: + operationId: createAssetFromHash + tags: [assets] + summary: Create an asset reference from an existing hash + description: Registers a new asset that references existing content by hash, without re-uploading the bytes. + x-feature-gate: enable-assets + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - hash + - name + properties: + hash: + type: string + description: Blake3 hash of existing content + name: + type: string + description: Display name + tags: + type: array + items: + type: string + user_metadata: + type: object + additionalProperties: true + responses: + "201": + description: Asset created from hash + content: + application/json: + schema: + $ref: "#/components/schemas/AssetCreated" + + /api/assets/{id}: + get: + operationId: getAsset + tags: [assets] + summary: Get asset metadata + description: Returns the metadata for a single asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Asset metadata + content: + application/json: + schema: + $ref: "#/components/schemas/Asset" + "404": + description: Asset not found + put: + operationId: updateAsset + tags: [assets] + summary: Update asset metadata + description: Updates the mutable metadata of an asset (name, tags, etc.). Binary content is immutable. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + properties: + name: + type: string + description: New display name for the asset + user_metadata: + type: object + additionalProperties: true + description: Custom user metadata to set + preview_id: + type: string + format: uuid + description: ID of the asset to use as the preview + responses: + "200": + description: Asset updated + content: + application/json: + schema: + $ref: "#/components/schemas/AssetUpdated" + delete: + operationId: deleteAsset + tags: [assets] + summary: Delete an asset + description: Removes an asset entry. Depending on the server configuration, the underlying content may also be deleted. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + - name: delete_content + in: query + schema: + type: boolean + description: Also delete the underlying content file + responses: + "204": + description: Asset deleted + + /api/assets/{id}/content: + get: + operationId: getAssetContent + tags: [assets] + summary: Download asset file content + description: Returns the binary content of an asset. Supports range requests. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + responses: + "200": + description: Asset file content + content: + application/octet-stream: + schema: + type: string + format: binary + "404": + description: Asset not found + + /api/assets/{id}/tags: + post: + operationId: addAssetTags + tags: [assets] + summary: Add tags to an asset + description: Adds one or more tags to an asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - tags + properties: + tags: + type: array + items: + type: string + responses: + "200": + description: Tags added + content: + application/json: + schema: + $ref: "#/components/schemas/TagsModificationResponse" + delete: + operationId: removeAssetTags + tags: [assets] + summary: Remove tags from an asset + description: Removes one or more tags from an asset. + x-feature-gate: enable-assets + parameters: + - name: id + in: path + description: The asset ID. + required: true + schema: + type: string + format: uuid + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - tags + properties: + tags: + type: array + items: + type: string + responses: + "200": + description: Tags removed + content: + application/json: + schema: + $ref: "#/components/schemas/TagsModificationResponse" + + /api/tags: + get: + operationId: listTags + tags: [assets] + summary: List all known tags with counts + description: Returns the list of all tags known to the asset database, with counts. + x-feature-gate: enable-assets + parameters: + - name: limit + in: query + schema: + type: integer + - name: offset + in: query + schema: + type: integer + - name: search + in: query + schema: + type: string + description: Search term for tag name + responses: + "200": + description: Tag list + content: + application/json: + schema: + $ref: "#/components/schemas/ListTagsResponse" + + /api/assets/tags/refine: + get: + operationId: refineAssetTags + tags: [assets] + summary: Get tag counts for assets matching current filters + description: Returns suggested additional tags that would refine a filtered asset query, together with the count of assets each tag would select. + x-feature-gate: enable-assets + parameters: + - name: include_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must have (AND logic) + - name: exclude_tags + in: query + schema: + type: array + items: + type: string + style: form + explode: true + description: Tags that assets must not have + - name: name_contains + in: query + schema: + type: string + description: Filter assets whose name contains this substring + - name: metadata_filter + in: query + schema: + type: string + description: JSON-encoded metadata key/value filter + - name: limit + in: query + schema: + type: integer + - name: offset + in: query + schema: + type: integer + - name: sort + in: query + schema: + type: string + description: Field to sort by + - name: order + in: query + schema: + type: string + enum: [asc, desc] + description: Sort direction + responses: + "200": + description: Tag histogram + content: + application/json: + schema: + $ref: "#/components/schemas/AssetTagHistogramResponse" + + /api/assets/seed: + post: + operationId: seedAssets + tags: [assets] + summary: Trigger asset scan/seed from filesystem + description: Starts a background job that scans the configured directories and registers any assets not yet present in the asset database. + x-feature-gate: enable-assets + requestBody: + required: false + content: + application/json: + schema: + type: object + properties: + roots: + type: array + items: + type: string + description: Root folder paths to scan (if omitted, scans all) + responses: + "200": + description: Seed started + content: + application/json: + schema: + type: object + properties: + status: + type: string + + /api/assets/seed/status: + get: + operationId: getAssetSeedStatus + tags: [assets] + summary: Get asset scan progress + description: Returns the progress and status of the most recently-started asset seed job. + x-feature-gate: enable-assets + responses: + "200": + description: Scan progress + content: + application/json: + schema: + type: object + additionalProperties: true + description: Scan progress details (files scanned, total, status, etc.) + + /api/assets/seed/cancel: + post: + operationId: cancelAssetSeed + tags: [assets] + summary: Cancel an in-progress asset scan + description: Requests cancellation of the currently-running asset seed job. + x-feature-gate: enable-assets + responses: + "200": + description: Scan cancelled + content: + application/json: + schema: + type: object + properties: + status: + type: string + + /api/assets/prune: + post: + operationId: pruneAssets + tags: [assets] + summary: Mark assets whose backing files no longer exist on disk + description: Starts a background job that removes asset entries whose underlying content no longer exists on disk. + x-feature-gate: enable-assets + responses: + "200": + description: Prune result + content: + application/json: + schema: + type: object + properties: + status: + type: string + marked: + type: integer + description: Number of assets marked as missing + +components: + parameters: + ComfyUserHeader: + name: Comfy-User + in: header + required: false + schema: + type: string + description: | + Identifies the active user in multi-user mode. Used for settings, + userdata, and history isolation. This is not a security mechanism — + it is an organisational convenience with no authentication behind it. + + schemas: + # ------------------------------------------------------------------- + # Prompt + # ------------------------------------------------------------------- + PromptRequest: + type: object + description: A workflow submission. Wraps the prompt graph plus optional client identifier and extra per-request data. + required: + - prompt + properties: + prompt: + type: object + description: | + The workflow graph to execute. Keys are node IDs (strings); + values are objects with class_type and inputs. + additionalProperties: true + number: + type: number + description: Priority number for the queue (lower numbers have higher priority) + front: + type: boolean + description: If true, adds the prompt to the front of the queue + extra_data: + type: object + description: Extra data associated with the prompt (e.g. extra_pnginfo) + additionalProperties: true + client_id: + type: string + description: WebSocket client ID to receive progress updates + prompt_id: + type: string + format: uuid + description: "Client-supplied prompt ID. Server generates a UUID if omitted." + partial_execution_targets: + type: array + items: + type: string + description: List of node IDs to execute (partial graph execution) + + PromptResponse: + type: object + description: Server acknowledgement of a workflow submission. Includes the assigned `prompt_id` and current queue position. + properties: + prompt_id: + type: string + format: uuid + description: Unique identifier for the prompt execution + number: + type: number + description: Priority number in the queue + node_errors: + type: object + description: Validation errors keyed by node ID + additionalProperties: + $ref: "#/components/schemas/NodeError" + error: + description: Top-level prompt error (string message or structured error) + oneOf: + - type: string + - $ref: "#/components/schemas/PromptError" + + PromptErrorResponse: + type: object + description: Error response when prompt validation fails + additionalProperties: true + + PromptError: + type: object + description: Structured prompt validation error + properties: + type: + type: string + message: + type: string + details: + type: string + + Error: + type: object + description: Detailed node-level error + properties: + type: + type: string + message: + type: string + details: + type: string + extra_info: + type: object + properties: + input_name: + type: string + additionalProperties: true + + NodeError: + type: object + description: Error details for a single node + properties: + errors: + type: array + items: + $ref: "#/components/schemas/Error" + class_type: + type: string + description: The node's class type + dependent_outputs: + type: array + items: {} + + PromptInfo: + type: object + description: Summary of a queued or recently-executed prompt, as returned by the queue and history endpoints. + properties: + exec_info: + type: object + properties: + queue_remaining: + type: integer + description: Number of items remaining in the queue + + # ------------------------------------------------------------------- + # Queue + # ------------------------------------------------------------------- + QueueInfo: + type: object + description: Queue information with pending and running items + properties: + queue_running: + type: array + description: Currently running queue items + items: + type: array + description: | + Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] + items: {} + prefixItems: + - type: number + description: Priority number + - type: string + format: uuid + description: prompt_id + - type: object + description: prompt graph + additionalProperties: true + - type: object + description: extra_data + additionalProperties: true + - type: array + description: outputs_to_execute (list of output node IDs) + items: + type: string + - type: object + description: sensitive data (may be omitted) + additionalProperties: true + queue_pending: + type: array + description: Pending queue items (oldest first) + items: + type: array + description: | + Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] + items: {} + prefixItems: + - type: number + description: Priority number + - type: string + format: uuid + description: prompt_id + - type: object + description: prompt graph + additionalProperties: true + - type: object + description: extra_data + additionalProperties: true + - type: array + description: outputs_to_execute (list of output node IDs) + items: + type: string + - type: object + description: sensitive data (may be omitted) + additionalProperties: true + + QueueManageRequest: + type: object + description: Request to clear or delete from queue + properties: + clear: + type: boolean + description: If true, clear all pending items + delete: + type: array + items: + type: string + description: Array of prompt IDs to delete from queue + + # ------------------------------------------------------------------- + # History + # ------------------------------------------------------------------- + HistoryEntry: + type: object + description: A single execution history entry + properties: + prompt: + type: array + description: | + Prompt tuple: [number, prompt_id, prompt_graph, extra_data, output_node_ids] + items: {} + outputs: + type: object + description: Output data from execution keyed by node ID + additionalProperties: true + status: + type: object + description: Execution status (status_str, completed, messages, etc.) + additionalProperties: true + meta: + type: object + description: Metadata about the execution and nodes + additionalProperties: true + + HistoryManageRequest: + type: object + description: Request to clear or delete history entries + properties: + clear: + type: boolean + description: If true, clear all history + delete: + type: array + items: + type: string + description: Array of prompt IDs to delete from history + + # ------------------------------------------------------------------- + # Jobs + # ------------------------------------------------------------------- + JobEntry: + type: object + description: Lightweight job data for list views + required: + - id + - status + properties: + id: + type: string + format: uuid + description: Unique job identifier (same as prompt_id) + status: + type: string + description: Current job status + create_time: + type: number + description: Job creation timestamp + execution_start_time: + type: number + description: Workflow execution start timestamp + execution_end_time: + type: number + description: Workflow execution end timestamp + preview_output: + type: object + additionalProperties: true + description: Primary preview output + outputs_count: + type: integer + description: Total number of output files + + JobDetailResponse: + type: object + description: Full job details including workflow and outputs + required: + - id + - status + properties: + id: + type: string + format: uuid + status: + type: string + workflow: + type: object + additionalProperties: true + description: Full ComfyUI workflow + outputs: + type: object + additionalProperties: true + description: Full outputs object from execution + execution_error: + $ref: "#/components/schemas/ExecutionError" + create_time: + type: number + update_time: + type: number + execution_start_time: + type: number + execution_end_time: + type: number + preview_output: + type: object + additionalProperties: true + outputs_count: + type: integer + execution_status: + type: object + additionalProperties: true + execution_meta: + type: object + additionalProperties: true + + ExecutionError: + type: object + description: Detailed execution error from ComfyUI + properties: + node_id: + type: string + description: ID of the node that failed + node_type: + type: string + description: Type name of the node + exception_message: + type: string + description: Human-readable error message + exception_type: + type: string + description: Python exception type + traceback: + type: array + items: + type: string + description: Traceback lines + current_inputs: + type: object + additionalProperties: true + current_outputs: + type: object + additionalProperties: true + + PaginationInfo: + type: object + description: Pagination metadata returned alongside list responses. + properties: + offset: + type: integer + limit: + type: integer + total: + type: integer + has_more: + type: boolean + + # ------------------------------------------------------------------- + # Upload / View + # ------------------------------------------------------------------- + UploadResult: + type: object + description: Response body returned by the image/mask upload endpoints, describing where the uploaded file now lives. + properties: + name: + type: string + description: Saved filename (may be renamed to avoid collisions) + subfolder: + type: string + description: Subfolder the file was saved to + type: + type: string + description: Directory type (input, temp) + + # ------------------------------------------------------------------- + # System + # ------------------------------------------------------------------- + DeviceStats: + type: object + description: GPU/compute device statistics + required: + - name + - type + - index + properties: + name: + type: string + description: Device name + type: + type: string + description: Device type (cuda, mps, cpu, etc.) + index: + type: number + description: Device index + vram_total: + type: number + description: Total VRAM in bytes + vram_free: + type: number + description: Free VRAM in bytes + torch_vram_total: + type: number + description: Total PyTorch-managed VRAM in bytes + torch_vram_free: + type: number + description: Free PyTorch-managed VRAM in bytes + + SystemStatsResponse: + type: object + description: Hardware, VRAM, Python, and ComfyUI version information for the running process. + required: + - system + - devices + properties: + system: + type: object + required: + - os + - python_version + - embedded_python + - comfyui_version + - pytorch_version + - argv + - ram_total + - ram_free + properties: + os: + type: string + description: Operating system + python_version: + type: string + description: Python version + embedded_python: + type: boolean + description: Whether using embedded Python + comfyui_version: + type: string + description: ComfyUI version string + pytorch_version: + type: string + description: PyTorch version + required_frontend_version: + type: string + description: Required frontend version + argv: + type: array + items: + type: string + description: Command line arguments + ram_total: + type: number + description: Total RAM in bytes + ram_free: + type: number + description: Free RAM in bytes + installed_templates_version: + type: string + nullable: true + description: Version of the currently installed workflow templates + required_templates_version: + type: string + nullable: true + description: Minimum required workflow templates version for this ComfyUI build + devices: + type: array + items: + $ref: "#/components/schemas/DeviceStats" + + # ------------------------------------------------------------------- + # Node / Object Info + # ------------------------------------------------------------------- + NodeInfo: + type: object + description: 'Definition of a registered node class: its inputs, outputs, category, and display metadata.' + properties: + input: + type: object + description: Input specifications (required and optional groups) + additionalProperties: true + input_order: + type: object + description: Ordered input names per group + additionalProperties: + type: array + items: + type: string + output: + type: array + items: + type: string + description: Output type names + output_is_list: + type: array + items: + type: boolean + description: Whether each output is a list + output_name: + type: array + items: + type: string + description: Display names of outputs + name: + type: string + description: Internal class name + display_name: + type: string + description: Human-readable display name + description: + type: string + description: Node description + python_module: + type: string + description: Python module implementing the node + category: + type: string + description: Node category path + output_node: + type: boolean + description: Whether this is an output node + output_tooltips: + type: array + items: + type: string + description: Tooltips for each output + deprecated: + type: boolean + description: Whether the node is deprecated + experimental: + type: boolean + description: Whether the node is experimental + api_node: + type: boolean + description: Whether this is an API node + is_input_list: + type: boolean + description: Whether the node accepts list inputs + dev_only: + type: boolean + description: Whether the node is developer-only (hidden in production UI) + has_intermediate_output: + type: boolean + description: Whether the node emits intermediate output during execution + search_aliases: + type: array + items: + type: string + description: Alternative search terms for finding this node + essentials_category: + type: string + description: Category override used by the essentials pack + + # ------------------------------------------------------------------- + # Models + # ------------------------------------------------------------------- + ModelFolder: + type: object + description: A configured model folder and the list of disk paths it resolves to. + required: + - name + - folders + properties: + name: + type: string + description: Model folder type name (e.g. "checkpoints") + folders: + type: array + items: + type: string + description: Filesystem paths for this model type + + ModelFile: + type: object + description: A single model file in a folder, with filesystem metadata. + required: + - name + - pathIndex + properties: + name: + type: string + description: Model filename + pathIndex: + type: integer + description: Index into the folder's paths array + modified: + type: number + description: File modification timestamp + created: + type: number + description: File creation timestamp + size: + type: integer + format: int64 + description: File size in bytes + + # ------------------------------------------------------------------- + # Subgraphs + # ------------------------------------------------------------------- + GlobalSubgraphInfo: + type: object + description: Metadata for a global subgraph blueprint (without full data) + required: + - source + - name + - info + properties: + source: + type: string + description: Source type ("templates" or "custom_node") + name: + type: string + description: Display name of the subgraph blueprint + info: + type: object + description: Additional information about the subgraph + required: + - node_pack + properties: + node_pack: + type: string + description: The node pack/module providing this subgraph + data: + type: string + description: The full subgraph JSON data (may be empty in list view) + + GlobalSubgraphData: + type: object + description: Full data for a global subgraph blueprint + required: + - source + - name + - info + - data + properties: + source: + type: string + description: Source type ("templates" or "custom_node") + name: + type: string + description: Display name of the subgraph blueprint + info: + type: object + description: Additional information about the subgraph + required: + - node_pack + properties: + node_pack: + type: string + description: The node pack/module providing this subgraph + data: + type: string + description: The full subgraph JSON data as a string + + # ------------------------------------------------------------------- + # Userdata + # ------------------------------------------------------------------- + UserDataResponse: + description: | + Response body for the POST endpoints `/api/userdata/{file}` and + `/api/userdata/{file}/move/{dest}`. Returns a single item whose + shape depends on the `full_info` query parameter. + x-variant-selector: + full_info=true: file-info object (`GetUserDataResponseFullFile`) + default: relative path string + oneOf: + - $ref: "#/components/schemas/GetUserDataResponseFullFile" + - type: string + description: Relative path of the written or moved file. Returned when `full_info` is absent or false. + + ListUserdataResponse: + description: | + Response body for `GET /api/userdata`. The array item shape is + determined by the `full_info` and `split` query parameters. + x-variant-selector: + full_info=true: array of file-info objects (`GetUserDataResponseFullFile`) + split=true: array of `[relative_path, ...path_components]` arrays + default: array of relative path strings + oneOf: + - type: array + items: + $ref: "#/components/schemas/GetUserDataResponseFullFile" + description: Returned when `full_info=true`. + - type: array + items: + type: array + items: + type: string + minItems: 2 + description: | + Returned when `split=true` and `full_info=false`. Each inner + array is `[relative_path, ...path_components]`. + - type: array + items: + type: string + description: Default shape — array of file paths relative to the user data root. + + GetUserDataResponseFullFile: + type: object + description: A single entry in a full-info user data listing. + properties: + path: + type: string + description: File name or path relative to the user directory + created: + type: number + description: Unix timestamp of file creation + size: + type: integer + description: File size in bytes + modified: + type: integer + format: int64 + description: Unix timestamp of last modification in milliseconds + + # ------------------------------------------------------------------- + # Assets + # ------------------------------------------------------------------- + Asset: + type: object + description: A registered asset — an input/output file tracked in the asset database with content hash and metadata. + required: + - id + - name + - size + - created_at + - updated_at + properties: + id: + type: string + format: uuid + description: Unique identifier for the asset + name: + type: string + description: Name of the asset file + asset_hash: + type: string + description: Blake3 hash of the asset content + pattern: "^blake3:[a-f0-9]{64}$" + size: + type: integer + format: int64 + description: Size of the asset in bytes + mime_type: + type: string + description: MIME type of the asset + tags: + type: array + items: + type: string + description: Tags associated with the asset + user_metadata: + type: object + description: Custom user metadata + additionalProperties: true + metadata: + type: object + description: System-managed metadata (read-only) + additionalProperties: true + readOnly: true + preview_url: + type: string + format: uri + description: URL for asset preview/thumbnail + preview_id: + type: string + format: uuid + description: ID of the preview asset if available + prompt_id: + type: string + format: uuid + description: ID of the prompt that created this asset + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + last_access_time: + type: string + format: date-time + is_immutable: + type: boolean + description: Whether this asset is immutable + + AssetCreated: + description: Response body returned after successfully registering a new asset. + allOf: + - $ref: "#/components/schemas/Asset" + - type: object + required: + - created_new + properties: + created_new: + type: boolean + description: Whether this was a new creation (true) or returned existing (false) + + AssetUpdated: + type: object + description: Response body returned after updating an asset's metadata. + required: + - id + - updated_at + properties: + id: + type: string + format: uuid + name: + type: string + asset_hash: + type: string + pattern: "^blake3:[a-f0-9]{64}$" + tags: + type: array + items: + type: string + mime_type: + type: string + user_metadata: + type: object + additionalProperties: true + updated_at: + type: string + format: date-time + + ListAssetsResponse: + type: object + description: Paginated list of assets. + required: + - assets + - total + - has_more + properties: + assets: + type: array + items: + $ref: "#/components/schemas/Asset" + total: + type: integer + has_more: + type: boolean + + TagInfo: + type: object + description: A tag known to the asset database, with the number of assets bearing it. + required: + - name + - count + properties: + name: + type: string + count: + type: integer + + ListTagsResponse: + type: object + description: Flat list of all tags, with counts. + required: + - tags + - total + - has_more + properties: + tags: + type: array + items: + $ref: "#/components/schemas/TagInfo" + total: + type: integer + has_more: + type: boolean + + AssetTagHistogramResponse: + type: object + description: Tags that would refine a filtered asset query, with the count of assets each tag would additionally select. + required: + - tag_counts + properties: + tag_counts: + type: object + additionalProperties: + type: integer + description: Map of tag names to occurrence counts + + TagsModificationResponse: + type: object + description: Response body returned after adding or removing tags on an asset. + required: + - total_tags + properties: + added: + type: array + items: + type: string + description: Tags successfully added + removed: + type: array + items: + type: string + description: Tags successfully removed + already_present: + type: array + items: + type: string + description: Tags already present (for add) + not_present: + type: array + items: + type: string + description: Tags not present (for remove) + total_tags: + type: array + items: + type: string + description: All tags on the asset after the operation + + # ------------------------------------------------------------------- + # Result / Output types + # ------------------------------------------------------------------- + ResultItem: + type: object + description: A single output file reference + properties: + filename: + type: string + subfolder: + type: string + type: + type: string + enum: [input, output, temp] + display_name: + type: string + + NodeOutputs: + type: object + description: | + Outputs from a single node execution. Known keys are listed below, + but custom nodes may add arbitrary keys (additionalProperties). + properties: + images: + type: array + items: + $ref: "#/components/schemas/ResultItem" + audio: + type: array + items: + $ref: "#/components/schemas/ResultItem" + video: + type: array + items: + $ref: "#/components/schemas/ResultItem" + animated: + type: array + items: + type: boolean + text: + oneOf: + - type: string + - type: array + items: + type: string + additionalProperties: true + + TerminalSize: + type: object + description: Terminal dimensions + properties: + cols: + type: number + row: + type: number + + LogEntry: + type: object + description: A single log entry + properties: + t: + type: string + description: Timestamp + m: + type: string + description: Log message + + StatusWsMessageStatus: + type: object + description: Inner payload of a `status` WebSocket message, describing the execution queue state. + properties: + exec_info: + type: object + required: + - queue_remaining + properties: + queue_remaining: + type: integer + + StatusWsMessage: + type: object + description: Initial status message sent on connect + queue status updates + properties: + status: + $ref: "#/components/schemas/StatusWsMessageStatus" + sid: + type: string + description: Session ID assigned by the server + + ProgressWsMessage: + type: object + description: Node execution progress (step N of M) + required: + - value + - max + - prompt_id + - node + properties: + value: + type: integer + description: Current step + max: + type: integer + description: Total steps + prompt_id: + type: string + node: + type: string + description: Node ID currently executing + + ProgressTextWsMessage: + type: object + description: Text-based progress update from a node + properties: + nodeId: + type: string + text: + type: string + prompt_id: + type: string + + NodeProgressState: + type: object + description: Progress state for a single node + properties: + value: + type: number + max: + type: number + state: + type: string + enum: [pending, running, finished, error] + node_id: + type: string + prompt_id: + type: string + display_node_id: + type: string + parent_node_id: + type: string + real_node_id: + type: string + + ProgressStateWsMessage: + type: object + description: Bulk progress state for all nodes in a prompt + required: + - prompt_id + - nodes + properties: + prompt_id: + type: string + nodes: + type: object + description: Map of node ID to progress state + additionalProperties: + $ref: "#/components/schemas/NodeProgressState" + + ExecutingWsMessage: + type: object + description: Fired when a node begins execution + required: + - node + - display_node + - prompt_id + properties: + node: + type: string + description: Node ID + display_node: + type: string + description: Display node ID (may differ for subgraphs) + prompt_id: + type: string + + ExecutedWsMessage: + type: object + description: Fired when a node completes execution with output + required: + - node + - display_node + - prompt_id + - output + properties: + node: + type: string + display_node: + type: string + prompt_id: + type: string + output: + $ref: "#/components/schemas/NodeOutputs" + merge: + type: boolean + description: Whether to merge with existing output + + ExecutionWsMessageBase: + type: object + description: Base fields for execution lifecycle messages + required: + - prompt_id + - timestamp + properties: + prompt_id: + type: string + timestamp: + type: integer + description: Unix timestamp in milliseconds + + ExecutionStartWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + description: Fired when prompt execution begins + + ExecutionSuccessWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + description: Fired when prompt execution completes successfully + + ExecutionCachedWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + nodes: + type: array + items: + type: string + description: List of node IDs that were cached + description: Fired when nodes are served from cache + + ExecutionInterruptedWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + node_id: + type: string + node_type: + type: string + executed: + type: array + items: + type: string + description: Node IDs that completed before interruption + description: Fired when execution is interrupted by user + + ExecutionErrorWsMessage: + allOf: + - $ref: "#/components/schemas/ExecutionWsMessageBase" + - type: object + properties: + node_id: + type: string + node_type: + type: string + executed: + type: array + items: + type: string + exception_message: + type: string + exception_type: + type: string + traceback: + type: array + items: + type: string + current_inputs: {} + current_outputs: {} + description: Fired when a node throws an exception during execution + + LogsWsMessage: + type: object + description: Streaming log entries from the server + properties: + size: + $ref: "#/components/schemas/TerminalSize" + entries: + type: array + items: + $ref: "#/components/schemas/LogEntry" + + NotificationWsMessage: + type: object + description: Server notification (e.g. model download complete) + properties: + value: + type: string + id: + type: string + + FeatureFlagsWsMessage: + type: object + description: Feature flags sent on connect + additionalProperties: true + + AssetDownloadWsMessage: + type: object + description: Asset download progress + required: + - task_id + - asset_name + - bytes_total + - bytes_downloaded + - progress + - status + properties: + task_id: + type: string + asset_name: + type: string + bytes_total: + type: number + bytes_downloaded: + type: number + progress: + type: number + description: 0.0 to 1.0 + status: + type: string + enum: [created, running, completed, failed] + asset_id: + type: string + error: + type: string + + AssetExportWsMessage: + type: object + description: Bulk asset export progress + required: + - task_id + - assets_total + - assets_attempted + - assets_failed + - bytes_total + - bytes_processed + - progress + - status + properties: + task_id: + type: string + export_name: + type: string + assets_total: + type: number + assets_attempted: + type: number + assets_failed: + type: number + bytes_total: + type: number + bytes_processed: + type: number + progress: + type: number + description: 0.0 to 1.0 + status: + type: string + enum: [created, running, completed, failed] + error: + type: string diff --git a/requirements.txt b/requirements.txt index 63d7c41cf..6c7457e03 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.42.11 -comfyui-workflow-templates==0.9.57 -comfyui-embedded-docs==0.4.3 +comfyui-frontend-package==1.42.15 +comfyui-workflow-templates==0.9.62 +comfyui-embedded-docs==0.4.4 torch torchsde torchvision @@ -23,7 +23,7 @@ SQLAlchemy>=2.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.12 +comfy-aimdo==0.2.14 requests simpleeval>=1.0.0 blake3 diff --git a/utils/install_util.py b/utils/install_util.py index 34489aec5..fdba23a8f 100644 --- a/utils/install_util.py +++ b/utils/install_util.py @@ -39,7 +39,7 @@ def get_required_packages_versions(): if len(s) == 2: version_str = s[-1] if not is_valid_version(version_str): - logging.error(f"Invalid version format in requirements.txt: {version_str}") + logging.debug(f"Invalid version format for {s[0]} in requirements.txt: {version_str}") continue out[s[0]] = version_str return out.copy()